Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mlr-org/mlr3mbo
Browse files Browse the repository at this point in the history
  • Loading branch information
sumny committed Nov 28, 2024
2 parents 4355e74 + b8253c2 commit 917ff4e
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 18 deletions.
10 changes: 6 additions & 4 deletions R/ResultAssignerSurrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
extra = xydt[, !c(cols_x, cols_y), with = FALSE]

# ys are still the ones originally evaluated
best_y = if (inherits(instance, c("OptimInstanceBatchSingleCrit", "OptimInstanceAsyncSingleCrit"))) {
unlist(archive$data[best, on = cols_x][, cols_y, with = FALSE])
if (inherits(instance, c("OptimInstanceBatchSingleCrit", "OptimInstanceAsyncSingleCrit"))) {
best_y = unlist(archive$data[best, on = cols_x][, cols_y, with = FALSE])
instance$assign_result(xdt = best, y = best_y, extra = extra)
} else if (inherits(instance, c("OptimInstanceBatchMultiCrit", "OptimInstanceAsyncMultiCrit"))) {
archive$data[best, on = cols_x][, cols_y, with = FALSE]
best_y = archive$data[best, on = cols_x][, cols_y, with = FALSE]
instance$assign_result(xdt = best, ydt = best_y, extra = extra)
}
instance$assign_result(xdt = best, y = best_y, extra = extra)

}
),

Expand Down
19 changes: 16 additions & 3 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
old_threshold_bbotk = lg$threshold
lg$set_threshold("warn")
old_opts = options(
warnPartialMatchArgs = TRUE,
warnPartialMatchAttr = TRUE,
warnPartialMatchDollar = TRUE
)

# https://github.com/HenrikBengtsson/Wishlist-for-R/issues/88
old_opts = lapply(old_opts, function(x) if (is.null(x)) FALSE else x)

lg_mlr3 = lgr::get_logger("mlr3")
lg_bbotk = lgr::get_logger("bbotk")
lg_rush = lgr::get_logger("rush")

old_threshold_mlr3 = lg_mlr3$threshold
lg_mlr3$set_threshold("warn")
old_threshold_bbotk = lg_bbotk$threshold
old_threshold_rush = lg_rush$threshold

lg_mlr3$set_threshold(0)
lg_bbotk$set_threshold(0)
lg_rush$set_threshold(0)
5 changes: 3 additions & 2 deletions tests/testthat/teardown.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
lg$set_threshold(old_threshold_bbotk)
options(old_opts)
lg_mlr3$set_threshold(old_threshold_mlr3)

lg_bbotk$set_threshold(old_threshold_bbotk)
lg_rush$set_threshold(old_threshold_rush)
2 changes: 1 addition & 1 deletion tests/testthat/test_ResultAssignerArchive.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ test_that("ResultAssignerArchive works with OptimizerMbo and bayesopt_ego", {
optimizer = opt("mbo", loop_function = bayesopt_ego, surrogate = surrogate, acq_function = acq_function, acq_optimizer = acq_optimizer, result_assigner = result_assigner)
optimizer$optimize(instance)
expect_true(nrow(instance$archive$data) == 5L)
expect_data_table(instance$result, nrow = 1L)
expect_data_table(instance$result, nrows = 1L)
})

test_that("ResultAssignerArchive works with OptimizerMbo and bayesopt_parego", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_ResultAssignerSurrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ test_that("ResultAssignerSurrogate works with OptimizerMbo and bayesopt_ego", {

expect_r6(result_assigner$surrogate, classes = "SurrogateLearner")
expect_r6(result_assigner$surrogate$learner, classes = "Learner")
expect_data_table(instance$result, nrow = 1L)
expect_data_table(instance$result, nrows = 1L)
})

test_that("ResultAssignerSurrogate works with OptimizerMbo and bayesopt_parego", {
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test_SurrogateLearnerCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ test_that("predict_types are recognized", {
learner1$predict_type = "se"
learner2 = lrn("regr.rpart")
learner2$predict_type = "response"
surrogate = SurrogateLearnerCollection$new(learner = list(learner1, learner2), archive = inst$archive)
surrogate = SurrogateLearnerCollection$new(learners = list(learner1, learner2), archive = inst$archive)
surrogate$update()

xdt = data.table(x = seq(-1, 1, length.out = 5L))
Expand All @@ -58,7 +58,7 @@ test_that("predict_types are recognized", {

test_that("param_set", {
inst = MAKE_INST(OBJ_1D_2, PS_1D, trm("evals", n_evals = 5L))
surrogate = SurrogateLearnerCollection$new(learner = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
surrogate = SurrogateLearnerCollection$new(learners = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
expect_r6(surrogate$param_set, "ParamSet")
expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measures", "perf_thresholds", "catch_errors", "impute_method"))
expect_equal(surrogate$param_set$class[["assert_insample_perf"]], "ParamLgl")
Expand All @@ -77,7 +77,7 @@ test_that("insample_perf", {
design = MAKE_DESIGN(inst)
inst$eval_batch(design)

surrogate = SurrogateLearnerCollection$new(learner = list(REGR_KM_DETERM, REGR_KM_DETERM$clone(deep = TRUE)), archive = inst$archive)
surrogate = SurrogateLearnerCollection$new(learners = list(REGR_KM_DETERM, REGR_KM_DETERM$clone(deep = TRUE)), archive = inst$archive)
expect_error({surrogate$insample_perf = c(0, 0)}, regexp = "insample_perf is read-only.")
expect_error({surrogate$assert_insample_perf = 0}, regexp = "assert_insample_perf is read-only.")

Expand All @@ -91,7 +91,7 @@ test_that("insample_perf", {
expect_double(surrogate$insample_perf, lower = -Inf, upper = 1, any.missing = FALSE, len = 2L)
expect_equal(names(surrogate$insample_perf), map_chr(surrogate$param_set$values$perf_measures, "id"))

surrogate_constant = SurrogateLearnerCollection$new(learner = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
surrogate_constant = SurrogateLearnerCollection$new(learners = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
surrogate_constant$param_set$values$assert_insample_perf = TRUE
surrogate_constant$param_set$values$perf_thresholds = c(0.5, 0.5)
surrogate_constant$param_set$values$perf_measures = list(mlr_measures$get("regr.rsq"), mlr_measures$get("regr.rsq"))
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_TunerADBO.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test_that("TunerADBO works", {
task = tsk("pima"),
learner = learner,
resampling = rsmp("cv", folds = 3L),
measure = msr("classif.ce"),
measures = msr("classif.ce"),
terminator = trm("evals", n_evals = 20L),
store_benchmark_result = FALSE
)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_TunerAsyncMbo.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test_that("TunerAsyncMbo works", {
task = tsk("pima"),
learner = learner,
resampling = rsmp("cv", folds = 3L),
measure = msr("classif.ce"),
measures = msr("classif.ce"),
terminator = trm("evals", n_evals = 20L),
store_benchmark_result = FALSE
)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_TunerMbo.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ test_that("TunerMbo reset", {
instance = TuningInstanceBatchSingleCrit$new(tsk("iris"), learner = learner, resampling = rsmp("holdout"), measure = msr("classif.ce"), terminator = trm("evals", n_evals = 5L))
tuner$optimize(instance)

instance_mult = TuningInstanceBatchMultiCrit$new(tsk("iris"), learner = learner, resampling = rsmp("holdout"), measure = msrs(c("classif.ce", "classif.logloss")), terminator = trm("evals", n_evals = 5L))
instance_mult = TuningInstanceBatchMultiCrit$new(tsk("iris"), learner = learner, resampling = rsmp("holdout"), measures = msrs(c("classif.ce", "classif.logloss")), terminator = trm("evals", n_evals = 5L))

expect_error(tuner$optimize(instance_mult), "does not support multi-crit objectives")
expect_loop_function(tuner$loop_function)
Expand Down

0 comments on commit 917ff4e

Please sign in to comment.