feat: add learner, resampling and measure weights #1124

Draft · wants to merge 12 commits into base: main
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -44,7 +44,7 @@ Depends:
R (>= 3.1.0)
Imports:
R6 (>= 2.4.1),
backports,
backports (>= 1.5.0),
checkmate (>= 2.0.0),
data.table (>= 1.15.0),
evaluate,
8 changes: 8 additions & 0 deletions NEWS.md
@@ -1,5 +1,13 @@
# mlr3 (development version)

* BREAKING CHANGE: The `weights` property and its functionality are split into `weights_learner`, `weights_measure`, and `weights_resampling`:

* `weights_learner`: Weights used during training by the Learner.
* `weights_measure`: Weights used during scoring predictions via measures.
* `weights_resampling`: Weights used during resampling to sample observations with unequal probability.

Each of these can be disabled via the new `use_weights` hyperparameter (Measure, Resampling) or field (Learner).

# mlr3 0.22.1

* fix: Extend `assert_measure()` with checks for trained models in `assert_scorable()`.
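To make the NEWS entry above concrete: a minimal sketch of the task-side usage, assuming the column roles and the `use_weights` field behave as described in this PR (the mtcars task and the choice of `"wt"` as the weight column are purely illustrative):

```r
library(mlr3)

task = tsk("mtcars")

# assign the numeric column "wt" the new column role "weights_learner";
# "weights_measure" and "weights_resampling" would be assigned the same way
task$set_col_roles("wt", roles = "weights_learner")

learner = lrn("regr.rpart")
learner$use_weights              # "use": rpart supports weights, so they are passed on
learner$train(task)

learner$use_weights = "ignore"   # fit an unweighted model despite the task weights
learner$train(task)
```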
60 changes: 58 additions & 2 deletions R/Learner.R
@@ -67,6 +67,19 @@
#' Only available for [`Learner`]s with the `"internal_tuning"` property.
#' If the learner is not trained yet, this returns `NULL`.
#'
#' @section Weights:
#'
#' Many learners support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task], where the column role `weights_learner` must be assigned to a single numeric column.
#' If a task has weights and the learner supports them, they are used automatically.
#' If a task has weights but the learner does not support them, an error is thrown.
#' Both of these behaviors can be disabled by setting the `use_weights` field to `"ignore"`.
#' See the description of `use_weights` for more information.
#'
#' If the learner is set up to use weights but the task does not have a designated weight column, an unweighted model is fitted instead.
#' When used, the weights are passed directly to the learner.
#' They do not need to sum to 1.
#'
#' @section Setting Hyperparameters:
#'
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
@@ -212,7 +225,6 @@ Learner = R6Class("Learner",
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
@@ -222,6 +234,13 @@
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

if ("weights" %in% self$properties) {
self$use_weights = "use"
} else {
self$use_weights = "error"
}
private$.param_set = assert_param_set(param_set)

check_packages_installed(packages, msg = sprintf("Package '%%s' required but not installed for Learner '%s'", id))
},

@@ -402,7 +421,7 @@ Learner = R6Class("Learner",
assert_names(newdata$colnames, must.include = task$feature_names)

# the following columns are automatically set to NA if missing
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE)
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weights_learner", "weights_measure", "weights_resampling")], use.names = FALSE)
impute = setdiff(impute, newdata$colnames)
if (length(impute)) {
# create list with correct NA types and cbind it to the backend
@@ -509,6 +528,26 @@
),

active = list(
#' @field use_weights (`character(1)`)\cr
#' How to use weights.
#' Settings are `"use"`, `"ignore"`, and `"error"`.
#'
#' * `"use"`: use weights, as supported by the underlying `Learner`.
#' * `"ignore"`: do not use weights.
#' * `"error"`: throw an error if weights are present in the training `Task`.
#'
#' For `Learner`s with the property `"weights"`, this is initialized as `"use"`.
#' For `Learner`s that do not support weights, i.e. without the `"weights"` property, this is initialized as `"error"`.
#' The latter default avoids cases where a user erroneously assumes that a `Learner` supports weights when it does not.
#' For such `Learner`s, `use_weights` must be set to `"ignore"` if tasks with weights should be handled (by dropping the weights).
use_weights = function(rhs) {
if (!missing(rhs)) {
assert_choice(rhs, c(if ("weights" %in% self$properties) "use", "ignore", "error"))
private$.use_weights = rhs
}
private$.use_weights
},

#' @field data_formats (`character()`)\cr
#' Supported data formats. Always `"data.table"`.
#' This is deprecated and will be removed in the future.
@@ -632,12 +671,29 @@
),

private = list(
.use_weights = NULL,
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
.predict_type = NULL,
.param_set = NULL,
.hotstart_stack = NULL,

# retrieve weights from a task, if it has weights and if the user did not
# deactivate weight usage through `self$use_weights`.
# - `task`: Task to retrieve weights from
# - `no_weights_val`: Value to return if no weights are found (default NULL)
# return: Numeric vector of weights or `no_weights_val` (default NULL)
.get_weights = function(task, no_weights_val = NULL) {
if ("weights" %nin% self$properties) {
stop("private$.get_weights should not be used in Learners that do not have the 'weights_learner' property.")
}
if (self$use_weights == "use" && "weights_learner" %in% task$properties) {
task$weights_learner$weight
} else {
no_weights_val
}
},

deep_clone = function(name, value) {
switch(name,
.param_set = value$clone(deep = TRUE),
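The `use_weights` field documented in the hunk above has three settings; a hedged sketch of how they behave, assuming the defaults set in `initialize()` as shown in this diff:

```r
library(mlr3)

learner = lrn("classif.rpart")
"weights" %in% learner$properties  # TRUE, so use_weights is initialized to "use"
learner$use_weights                # "use"
learner$use_weights = "ignore"     # train unweighted even if the task has weights

# A learner without the "weights" property starts at "error": training on a task
# that has a weights_learner column then fails loudly. To drop the weights instead,
# switch to "ignore" (the value "use" is rejected for such learners):
# some_learner$use_weights = "ignore"   # `some_learner` is a placeholder
```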
8 changes: 2 additions & 6 deletions R/LearnerClassifRpart.R
@@ -35,9 +35,8 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "classif.rpart",
@@ -77,10 +76,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

8 changes: 2 additions & 6 deletions R/LearnerRegrRpart.R
@@ -35,9 +35,8 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "regr.rpart",
@@ -77,10 +76,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

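For learner packages outside mlr3, the two rpart diffs above suggest a migration pattern along these lines; this is a sketch only, with `my_fit_function` standing in for whatever fitting routine the learner wraps:

```r
# inside an R6 Learner definition, in private = list(...):
.train = function(task) {
  pv = self$param_set$get_values(tags = "train")

  # old pattern (single "weights" column role):
  #   if ("weights" %in% task$properties) {
  #     pv = insert_named(pv, list(weights = task$weights$weight))
  #   }

  # new pattern: private$.get_weights() returns the weights_learner column when
  # use_weights == "use" and the task defines one, otherwise NULL (unweighted fit)
  pv$weights = private$.get_weights(task)

  mlr3misc::invoke(my_fit_function,
    formula = task$formula(), data = task$data(), .args = pv)
}
```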
37 changes: 27 additions & 10 deletions R/Measure.R
@@ -24,6 +24,14 @@
#' In such cases it is necessary to overwrite the public methods `$aggregate()` and/or `$score()` to return a named `numeric()`
#' where at least one of its names corresponds to the `id` of the measure itself.
#'
#' @section Weights:
#'
#' Many measures support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task], where the column role `weights_measure` must be assigned to a single numeric column.
#' The weights are used automatically if found in the task; this can be disabled by setting the hyperparameter `use_weights` to `FALSE`.
#' If the measure is set up to use weights but the task does not have a designated weight column, the unweighted score is calculated instead.
#' The weights do not need to sum to 1; they are normalized by dividing by the sum of weights.
#'
#' @template param_id
#' @template param_param_set
#' @template param_range
@@ -94,10 +102,6 @@ Measure = R6Class("Measure",
#' Lower and upper bound of possible performance scores.
range = NULL,

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = NULL,

#' @field minimize (`logical(1)`)\cr
#' If `TRUE`, good predictions correspond to small values of performance scores.
minimize = NULL,
@@ -117,7 +121,6 @@
predict_sets = "test", task_properties = character(), packages = character(),
label = NA_character_, man = NA_character_, trafo = NULL) {

self$properties = unique(properties)
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = task_type
@@ -140,6 +143,8 @@
assert_subset(task_properties, mlr_reflections$task_properties[[task_type]])
}


self$properties = unique(properties)
self$predict_type = predict_type
self$predict_sets = predict_sets
self$task_properties = task_properties
@@ -195,24 +200,25 @@
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_scorable(self, task = task, learner = learner, prediction = prediction)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)
properties = self$properties
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% properties)

# check should be added to assert_measure()
# except when the checks are superfluous for rr$score() and bmr$score()
# these checks should be added below
if ("requires_task" %in% self$properties && is.null(task)) {
if ("requires_task" %in% properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}

if ("requires_learner" %in% self$properties && is.null(learner)) {
if ("requires_learner" %in% properties && is.null(learner)) {
stopf("Measure '%s' requires a learner", self$id)
}

if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
if ("requires_train_set" %in% properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

@@ -227,7 +233,6 @@
#'
#' @return `numeric(1)`.
aggregate = function(rr) {

switch(self$average,
"macro" = {
aggregator = self$aggregator %??% mean
@@ -274,6 +279,17 @@
self$predict_sets, mget(private$.extra_hash, envir = self))
},

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = function(rhs) {
if (!missing(rhs)) {
props = if (is.na(self$task_type)) unique(unlist(mlr_reflections$measure_properties), use.names = FALSE) else mlr_reflections$measure_properties[[self$task_type]]
private$.properties = assert_subset(rhs, props)
} else {
private$.properties
}
},

#' @field average (`character(1)`)\cr
#' Method for aggregation:
#'
@@ -306,6 +322,7 @@
),

private = list(
.properties = character(),
.predict_sets = NULL,
.extra_hash = character(),
.average = NULL,
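The normalization described in the Measure weights section above (dividing by the sum of weights rather than requiring weights that sum to 1) amounts to a weighted mean of the per-observation losses; a small self-contained sketch with illustrative numbers:

```r
truth    = c(1.0, 2.0, 3.0, 4.0)
response = c(1.1, 1.8, 3.5, 3.9)
w        = c(1, 1, 4, 2)          # weights need not sum to 1

loss = (truth - response)^2       # per-observation squared errors

mean(loss)                        # unweighted aggregation
sum(w * loss) / sum(w)            # weighted aggregation, normalized by sum(w)
weighted.mean(loss, w)            # equivalent built-in
```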
2 changes: 2 additions & 0 deletions R/MeasureRegrRSQ.R
@@ -50,6 +50,8 @@ MeasureRegrRSQ = R6Class("MeasureRSQ",
),

private = list(
# this is not included in the paramset as this flag influences properties of the measure
# so this flag should not be "dynamic state"
.pred_set_mean = NULL,

.score = function(prediction, task = NULL, train_set = NULL, ...) {