diff --git a/NEWS.md b/NEWS.md index 2ef616b..7465a11 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,6 +5,10 @@ - Factor-valued predictions are no longer possible. - Consequently, also removed "classification_error" loss. +## Minor changes + +- Code simplifications. + # hstats 1.1.2 ## ICE plots diff --git a/R/losses.R b/R/losses.R index b207b5b..f9b40de 100644 --- a/R/losses.R +++ b/R/losses.R @@ -12,8 +12,8 @@ check_loss <- function(actual, predicted) { stopifnot( is.vector(actual) || is.matrix(actual), is.vector(predicted) || is.matrix(predicted), - is.numeric(actual), - is.numeric(predicted), + is.numeric(actual) || is.logical(actual), + is.numeric(predicted) || is.logical(predicted), NROW(actual) == NROW(predicted), NCOL(actual) == 1L || NCOL(actual) == NCOL(predicted) ) @@ -27,10 +27,14 @@ check_loss <- function(actual, predicted) { #' @noRd #' @keywords internal #' -#' @param actual A numeric vector or matrix. +#' @param actual A numeric vector or matrix, or a factor with levels in the same order +#' as the column names of `predicted`. #' @param predicted A numeric vector or matrix. #' @returns Vector or matrix of numeric losses. loss_squared_error <- function(actual, predicted) { + if (is.factor(actual)) { + actual <- fdummy(actual) + } check_loss(actual, predicted) return((drop(actual) - predicted)^2) @@ -147,11 +151,10 @@ loss_mlogloss <- function(actual, predicted) { is.matrix(actual), is.matrix(predicted), - is.numeric(actual), - is.numeric(predicted), + is.numeric(actual) || is.logical(actual), + is.numeric(predicted) || is.logical(predicted), - nrow(actual) == nrow(predicted), - ncol(actual) == ncol(predicted), + dim(actual) == dim(predicted), ncol(predicted) >= 2L, all(predicted >= 0), @@ -176,6 +179,7 @@ loss_mlogloss <- function(actual, predicted) { xlogy <- function(x, y) { out <- x * log(y) out[x == 0] <- 0 + return(out) } diff --git a/R/perm_importance.R b/R/perm_importance.R index 5df1b02..20fdcfb 100644 --- a/R/perm_importance.R +++ b/R/perm_importance.R @@ -103,9 +103,9 @@ perm_importance.default <- function(object, X, y, v = NULL, if (nrow(X) > n_max) { ix <- sample(nrow(X), n_max) X <- X[ix, , drop = FALSE] - if (is.vector(y)) { + if (is.vector(y) || is.factor(y)) { y <- y[ix] - } else { + } else { # matrix case y <- y[ix, , drop = FALSE] } if (!is.null(w)) { @@ -126,9 +126,9 @@ perm_importance.default <- function(object, X, y, v = NULL, if (m_rep > 1L) { ind <- rep.int(seq_len(n), m_rep) X <- rep_rows(X, ind) - if (is.vector(y)) { + if (is.vector(y) || is.factor(y)) { y <- y[ind] - } else { + } else { # matrix case y <- y[ind, , drop = FALSE] } } diff --git a/R/utils_input.R b/R/utils_input.R index d3d162f..d0845e3 100644 --- a/R/utils_input.R +++ b/R/utils_input.R @@ -14,7 +14,7 @@ prepare_pred <- function(x) { if (!is.vector(x) && !is.matrix(x)) { x <- as.matrix(x) } - if (!is.numeric(x)) { + if (!is.numeric(x) && !is.logical(x)) { stop("Predictions must be numeric!") } return(x) @@ -116,13 +116,10 @@ prepare_y <- function(y, X) { stopifnot(NROW(y) == nrow(X)) y_names <- NULL } - if (is.factor(y)) { - y <- fdummy(y) - } - if (!is.vector(y) && !is.matrix(y)) { + if (!is.vector(y) && !is.matrix(y) && !is.factor(y)) { y <- as.matrix(y) } - if (!is.numeric(y)) { + if (!is.numeric(y) && !is.logical(y) && !is.factor(y)) { stop("Response must be numeric (or factor.)") } list(y = y, y_names = y_names) diff --git a/backlog/benchmark.R b/backlog/benchmark.R index c367b67..3923a1c 100644 --- a/backlog/benchmark.R +++ b/backlog/benchmark.R @@ -115,12 +115,11 @@ bench::mark( check = FALSE, min_iterations = 3 ) - -# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time -# 1 iml 1.72s 1.75s 0.574 210.6MB 1.34 3 7 5.23s -# 2 dalex 744.82ms 760.02ms 1.31 35.2MB 0.877 3 2 2.28s -# 3 flashlight 1.29s 1.35s 0.742 63MB 0.990 3 4 4.04s -# 4 hstats 407.26ms 412.31ms 2.43 26.5MB 0 3 0 1.23s +# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result +# 1 iml 1.76s 1.76s 0.565 211.6MB 3.39 3 18 5.31s +# 2 dalex 688.54ms 697.71ms 1.44 35.2MB 1.91 3 4 2.09s +# 3 flashlight 667.51ms 676.07ms 1.47 28.1MB 1.96 3 4 2.04s +# 4 hstats 392.15ms 414.41ms 2.39 26.6MB 0.796 3 1 1.26s # Partial dependence (cont) v <- "tot_lvg_area" @@ -132,12 +131,12 @@ bench::mark( check = FALSE, min_iterations = 3 ) -# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time -# 1 iml 1.14s 1.16s 0.861 376.7MB 3.73 3 13 3.48s -# 2 dalex 653.24ms 654.51ms 1.35 192.8MB 2.24 3 5 2.23s -# 3 flashlight 352.34ms 361.79ms 2.72 66.7MB 0.906 3 1 1.1s -# 4 hstats 239.03ms 242.79ms 4.04 14.2MB 1.35 3 1 743.43ms - +# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result +# +# 1 iml 1.2s 1.4s 0.726 376.9MB 4.12 3 17 4.13s +# 2 dalex 759.3ms 760.6ms 1.28 192.8MB 2.55 3 6 2.35s +# 3 flashlight 369.1ms 403.1ms 2.55 66.8MB 2.55 3 3 1.18s +# 4 hstats 242.1ms 243.8ms 4.03 14.2MB 0 3 0 744.25ms # # Partial dependence (discrete) v <- "structure_quality" bench::mark( @@ -148,30 +147,31 @@ bench::mark( check = FALSE, min_iterations = 3 ) -# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time -# 1 iml 100.6ms 103.6ms 9.46 13.34MB 0 5 0 529ms -# 2 dalex 172.4ms 177.9ms 5.62 20.55MB 2.81 2 1 356ms -# 3 flashlight 43.5ms 45.5ms 21.9 6.36MB 2.19 10 1 457ms -# 4 hstats 25.3ms 25.8ms 37.9 1.54MB 2.10 18 1 475ms - +# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result +# +# 1 iml 107.9ms 108ms 9.26 13.64MB 9.26 2 2 216ms +# 2 dalex 172ms 172.2ms 5.81 21.14MB 2.90 2 1 344ms +# 3 flashlight 40.3ms 41.6ms 23.8 8.61MB 2.16 11 1 462ms +# 4 hstats 24.5ms 25.9ms 35.5 1.64MB 0 18 0 507ms + # H-Stats -> we use a subset of 500 rows X_v500 <- X_valid[1:500, ] mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf) fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ])) -# iml # 225s total, using slow exact calculations -system.time( # 90s +# iml # 243s total, using slow exact calculations +system.time( # 110s iml_overall <- Interaction$new(mod500, grid.size = 500) ) -system.time( # 135s for all combinations of latitude +system.time( # 133s for all combinations of latitude iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude") ) -# flashlight: 13s total, doing only one pairwise calculation, otherwise would take 63s -system.time( # 11.5s +# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s +system.time( # 11.7s fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf) ) -system.time( # 2.4s +system.time( # 2.3s fl_pairwise <- light_interaction( fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE ) @@ -185,34 +185,38 @@ system.time({ } ) -# Using 50 quantiles to approximate dense numerics: 0.9s +# Using 50 quantiles to approximate dense numerics: 0.8s system.time( H_approx <- hstats(fit, v = x, X = X_v500, n_max = Inf, approx = TRUE) ) # Overall statistics correspond exactly -iml_overall$results |> filter(.interaction > 1e-6) +iml_overall$results |> + filter(.interaction > 1e-6) # .feature .interaction -# 1: latitude 0.2791144 -# 2: longitude 0.2791144 +# 1: latitude 0.2458269 +# 2: longitude 0.2458269 -fl_overall$data |> subset(value > 0, select = c(variable, value)) -# variable value -# 1 latitude 0.279 -# 2 longitude 0.279 +fl_overall$data |> + subset(value_ > 0, select = c(variable_, value_)) +# variable_ value_ +# 3 latitude 0.2458269 +# 4 longitude 0.2458269 hstats_overall # longitude latitude -# 0.2791144 0.2791144 +# 0.2458269 0.2458269 # Pairwise results match as well -iml_pairwise$results |> filter(.interaction > 1e-6) +iml_pairwise$results |> + filter(.interaction > 1e-6) # .feature .interaction -# 1: longitude:latitude 0.4339574 +# 1: longitude:latitude 0.3942526 -fl_pairwise$data |> subset(value > 0, select = c(variable, value)) -# latitude:longitude 0.434 +fl_pairwise$data |> + subset(value_ > 0, select = c(variable_, value_)) +# latitude:longitude 0.3942526 hstats_pairwise # latitude:longitude -# 0.4339574 \ No newline at end of file +# 0.3942526 \ No newline at end of file diff --git a/backlog/colMeans_factors.R b/backlog/colMeans_factors.R deleted file mode 100644 index adcc54c..0000000 --- a/backlog/colMeans_factors.R +++ /dev/null @@ -1,27 +0,0 @@ -#' gcolMeans() for Factors -#' -#' Grouped version of `colMeans_factor()`. -#' -#' @noRd -#' @keywords internal -#' -#' @params x Factor. -#' @returns Named vector. -gcolMeans_factor <- function(x, g = NULL) { - if (is.null(g)) { - colMeans_factor(x) - } - x <- as.factor(x) - out <- t.default(sapply(split.default(x, g), colMeans_factor)) - colnames(out) <- levels(x) - out -} - -wrowmean_matrix2 <- function(x, ngroups = 1L, w = NULL) { - if (!is.matrix(x)) { - stop("x must be a matrix.") - } - dim(x) <- c(nrow(x)/ngroups, ngroups, ncol(x)) - out <- colMeans(aperm(x, c(1, 3, 2))) - t.default(out) -} diff --git a/tests/testthat/test_average_loss.R b/tests/testthat/test_average_loss.R index a6a7786..fd8d1c6 100644 --- a/tests/testthat/test_average_loss.R +++ b/tests/testthat/test_average_loss.R @@ -230,7 +230,7 @@ test_that("mlogloss works with either matrix y or factor y", { }) test_that("loss_mlogloss() is in line with loss_logloss() in binary case", { - y <- (iris$Species == "setosa") * 1 + y <- (iris$Species == "setosa") Y <- cbind(no = 1 - y, yes = y) fit <- glm(y ~ Sepal.Length, data = iris, family = binomial()) pf <- function(m, X, multi = FALSE) { diff --git a/tests/testthat/test_input.R b/tests/testthat/test_input.R index f8a1910..3bb5616 100644 --- a/tests/testthat/test_input.R +++ b/tests/testthat/test_input.R @@ -24,7 +24,7 @@ test_that("prepare_by() works", { test_that("prepare_y() works", { # "Vector" interface expect_equal(prepare_y(iris[1:4], X = iris)$y, data.matrix(iris[1:4])) - expect_equal(prepare_y(iris["Species"], X = iris)$y, fdummy(iris$Species)) + expect_equal(prepare_y(iris["Species"], X = iris)$y, iris$Species) expect_equal(prepare_y(iris$Sepal.Width, X = iris)$y, iris$Sepal.Width) expect_equal(prepare_y(iris["Sepal.Width"], X = iris)$y, iris$Sepal.Width) @@ -35,7 +35,7 @@ test_that("prepare_y() works", { expect_equal(out$y_names, cn) out <- prepare_y("Species", X = iris) - expect_equal(out$y, fdummy(iris$Species)) + expect_equal(out$y, iris$Species) expect_equal(out$y_names, "Species") out <- prepare_y("Sepal.Width", X = iris) diff --git a/tests/testthat/test_perm_importance.R b/tests/testthat/test_perm_importance.R index 3cd4508..15247a8 100644 --- a/tests/testthat/test_perm_importance.R +++ b/tests/testthat/test_perm_importance.R @@ -359,7 +359,7 @@ test_that("Single output multiple models works without recycling y", { }) test_that("loss_mlogloss() is in line with loss_logloss() in binary case", { - y <- (iris$Species == "setosa") * 1 + y <- (iris$Species == "setosa") Y <- cbind(no = 1 - y, yes = y) fit <- glm(y ~ Sepal.Length, data = iris, family = binomial()) pf <- function(m, X, multi = FALSE) {