Skip to content

Commit

Permalink
Unify forests 1 (#41)
Browse files Browse the repository at this point in the history
* core functions

* rename tree_depth function

* simplify min_depth_distribution.R

* simplify measure_importance.R

* simplify min_depth_interaction.R

* Update NAMESPACE

* Update rda files of vignette

* remove two very fast to calculate temp results
  • Loading branch information
mayer79 authored Mar 22, 2024
1 parent 96572e0 commit c92335e
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 306 deletions.
4 changes: 0 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

S3method(measure_importance,randomForest)
S3method(measure_importance,ranger)
S3method(min_depth_distribution,randomForest)
S3method(min_depth_distribution,ranger)
S3method(min_depth_interactions,randomForest)
S3method(min_depth_interactions,ranger)
S3method(plot_predict_interaction,randomForest)
S3method(plot_predict_interaction,ranger)
export(explain_forest)
Expand Down
55 changes: 14 additions & 41 deletions R/measure_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@ measure_min_depth <- function(min_depth_frame, mean_sample){
}

# Count how many nodes of the forest split on each variable.
# `forest_table` is a data frame covering the whole forest with a `split var`
# column (randomForest naming). Rows whose split variable is NA are dropped.
# Returns a plain data.frame with columns `variable` (character) and
# `no_of_nodes`.
measure_no_of_nodes <- function(forest_table){
  # Local placeholder for the column name used in the grouping below;
  # presumably only there to avoid an "undefined global" check note.
  `split var` <- NULL
  by_variable <- dplyr::group_by(forest_table, variable = `split var`)
  node_counts <- dplyr::summarize(by_variable, no_of_nodes = dplyr::n())
  has_variable <- !is.na(node_counts$variable)
  node_counts <- as.data.frame(node_counts[has_variable,])
  node_counts$variable <- as.character(node_counts$variable)
  return(node_counts)
}

# Calculate the number of nodes split on each variable for a data frame with the whole forest
# ranger
# Count how many nodes split on each variable and return the counts as a
# data.frame with columns `variable` (character) and `no_of_nodes`; groups
# whose variable is NA are removed.
# NOTE(review): `frame` is assigned twice in a row below. The first statement
# groups by `splitvarName`, the second by `variable` and overwrites the first.
# This looks like the before/after lines of a diff shown together - confirm
# which one is meant to be live.
measure_no_of_nodes_ranger <- function(forest_table){
# Local placeholder for the column name used in the first grouping.
splitvarName <- NULL
frame <- dplyr::group_by(forest_table, variable = splitvarName) %>% dplyr::summarize(no_of_nodes = n())
# This second assignment replaces the result computed just above.
frame <- dplyr::group_by(forest_table, variable = variable) %>%
dplyr::summarize(no_of_nodes = dplyr::n())
# Drop the group that collects rows without a split variable.
frame <- as.data.frame(frame[!is.na(frame$variable),])
frame$variable <- as.character(frame$variable)
return(frame)
}

Expand Down Expand Up @@ -73,17 +62,18 @@ measure_vimp_ranger <- function(forest){
# Count the rows of `min_depth_frame` per `variable` and return them as a
# data.frame with a `no_of_trees` column.
# presumably `min_depth_frame` has one row per (tree, variable) pair, so the
# row count equals the number of trees using the variable - verify against
# the caller.
measure_no_of_trees <- function(min_depth_frame){
# Local placeholder for the column name used in group_by().
variable <- NULL
frame <- dplyr::group_by(min_depth_frame, variable) %>%
dplyr::summarize(no_of_trees = n()) %>% as.data.frame()
frame$variable <- as.character(frame$variable)
# NOTE(review): the two lines below form a pipeline with no data input whose
# result is not assigned. They look like the post-change tail of the pipe
# above, shown next to the pre-change lines by the diff - confirm.
dplyr::summarize(no_of_trees = n()) %>%
as.data.frame()
return(frame)
}

# Calculate the number of times each variable is split on the root node
# Keep only the rows of `min_depth_frame` with `minimal_depth == 0`, count
# them per `variable`, and return the counts as a data.frame with a
# `times_a_root` column.
measure_times_a_root <- function(min_depth_frame){
# Local placeholder for the column name used in group_by().
variable <- NULL
# minimal_depth == 0 selects rows where the variable was used at depth 0.
frame <- min_depth_frame[min_depth_frame$minimal_depth == 0, ] %>%
dplyr::group_by(variable) %>% dplyr::summarize(times_a_root = n()) %>% as.data.frame()
frame$variable <- as.character(frame$variable)
# NOTE(review): the three lines below form a pipeline that is not fed by the
# filtered frame and whose result is not assigned. They look like the
# post-change tail of the pipe above, interleaved by the diff - confirm.
dplyr::group_by(variable) %>%
dplyr::summarize(times_a_root = n()) %>%
as.data.frame()
return(frame)
}

Expand Down Expand Up @@ -142,18 +132,9 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
if (is.null(forest$forest)) {
stop("Make sure forest has been saved when calling randomForest by randomForest(..., keep.forest = TRUE).")
}
forest_table <- lapply(
1:forest$ntree,
function(i)
randomForest::getTree(forest, k = i, labelVar = TRUE) %>%
mutate(`split var` = as.character(`split var`)) %>%
calculate_tree_depth() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, `split var`) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
forest_table <- forest2df(forest)
min_depth_frame <- dplyr::group_by(forest_table, tree, variable) %>%
dplyr::summarize(minimal_depth = min(depth), .groups = "drop")
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
}
# Add each importance measure to the table (if it was requested)
Expand Down Expand Up @@ -206,25 +187,17 @@ measure_importance.ranger <- function(forest, mean_sample = "top_trees", measure
importance_frame <- data.frame(variable = names(forest$variable.importance), stringsAsFactors = FALSE)
# Get objects necessary to calculate importance measures based on the tree structure
if(any(c("mean_min_depth", "no_of_nodes", "no_of_trees", "times_a_root", "p_value") %in% measures)){
forest_table <- lapply(
1:forest$num.trees,
function(i)
ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
forest_table <- forest2df(forest)
min_depth_frame <- dplyr::group_by(forest_table, tree, variable) %>%
dplyr::summarize(minimal_depth = min(depth), .groups = "drop")
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
}
# Add each importance measure to the table (if it was requested)
if("mean_min_depth" %in% measures){
importance_frame <- merge(importance_frame, measure_min_depth(min_depth_frame, mean_sample), all = TRUE)
}
if("no_of_nodes" %in% measures){
importance_frame <- merge(importance_frame, measure_no_of_nodes_ranger(forest_table), all = TRUE)
importance_frame <- merge(importance_frame, measure_no_of_nodes(forest_table), all = TRUE)
importance_frame[is.na(importance_frame$no_of_nodes), "no_of_nodes"] <- 0
}
if(forest$importance.mode %in% measures){
Expand Down
93 changes: 8 additions & 85 deletions R/min_depth_distribution.R
Original file line number Diff line number Diff line change
@@ -1,48 +1,3 @@
# Add a `depth` column to a single tree table (as produced by
# randomForest::getTree) by passing its two daughter columns to
# calculate_tree_depth_(). Stops with an error when either daughter column
# is absent. Returns the input frame with the extra column.
calculate_tree_depth <- function(frame){
  daughter_cols <- c("left daughter", "right daughter")
  absent <- setdiff(daughter_cols, names(frame))
  if (length(absent) > 0) {
    stop("The data frame has to contain columns called 'right daughter' and 'left daughter'!
It should be a product of the function getTree(..., labelVar = T).")
  }
  node_depth <- calculate_tree_depth_(frame[, daughter_cols])
  frame[["depth"]] <- node_depth
  return(frame)
}

# Add a `depth` column to a single tree table (as produced by
# ranger::treeInfo) by passing its two child columns to
# calculate_tree_depth_(). Stops with an error when either child column is
# absent. Returns the input frame with the extra column.
calculate_tree_depth_ranger <- function(frame){
  child_cols <- c("leftChild", "rightChild")
  absent <- setdiff(child_cols, names(frame))
  if (length(absent) > 0) {
    stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
It should be a product of the function ranger::treeInfo().")
  }
  # The child ids are zero based; shift them by 1 so they can be used as
  # row indices by the helper.
  one_based_children <- frame[, child_cols] + 1
  frame[["depth"]] <- calculate_tree_depth_(one_based_children)
  return(frame)
}

# Internal function used to determine the depth of each node of one tree.
#
# Input:  `childs` - a data.frame (or matrix) with one row per node and two
#         columns holding the row indices (in 1:nrow(childs)) of the left and
#         right child. Entries that are NA or smaller than 1 mean "no child".
#         Row 1 is taken to be the root.
# Output: a numeric vector with one depth per row; the root has depth 0.
#         Rows that cannot be reached from the root keep the value NA.
#         A zero-row input yields numeric(0).
calculate_tree_depth_ <- function(childs) {
  childs <- as.matrix(childs)
  n <- nrow(childs)

  # Without this guard, assigning the root depth below would grow the empty
  # vector and a length-1 result would be returned for a tree with no rows.
  if (n == 0L) {
    return(numeric(0))
  }

  depth <- rep(NA_real_, times = n)
  depth[1L] <- 0
  j <- 0
  ix <- 1L # current nodes, initialized with root node index

  # j loops over tree depth. The loop ends when
  #  * every node has a depth (the usual case),
  #  * the current level has no children left (remaining nodes are
  #    unreachable from the root and stay NA), or
  #  * j reaches n, which bounds the work if the child indices form a cycle.
  while (length(ix) > 0L && anyNA(depth) && j < n) {
    ix <- as.integer(childs[ix, ])
    ix <- ix[!is.na(ix) & ix >= 1L] # leaf nodes do not have children
    j <- j + 1
    depth[ix] <- j
  }

  return(depth)
}

#' Calculate minimal depth distribution of a random forest
#'
#' Get minimal depth values for all trees in a random forest
Expand All @@ -56,47 +11,13 @@ calculate_tree_depth_ <- function(childs) {
#' min_depth_distribution(ranger::ranger(Species ~ ., data = iris, num.trees = 100))
#'
#' @export
min_depth_distribution <- function(forest) {
  # S3 generic: hand the call over to the method registered for the class
  # of `forest`.
  UseMethod("min_depth_distribution")
}

#' @import dplyr
#' @importFrom data.table rbindlist
#' @export
min_depth_distribution.randomForest <- function(forest){
min_depth_distribution <- function(forest){
tree <- NULL; `split var` <- NULL; depth <- NULL
forest_table <- lapply(
1:forest$ntree,
function(i)
randomForest::getTree(forest, k = i, labelVar = TRUE) %>%
mutate(`split var` = as.character(`split var`)) %>%
calculate_tree_depth() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, `split var`) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
return(min_depth_frame)
}

#' @import dplyr
#' @importFrom data.table rbindlist
#' @export
min_depth_distribution.ranger <- function(forest){
tree <- NULL; splitvarName <- NULL; depth <- NULL
forest_table <- lapply(
1:forest$num.trees,
function(i)
ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
forest_table <- forest2df(forest)
min_depth_frame <- dplyr::group_by(forest_table, tree, variable) %>%
dplyr::summarize(minimal_depth = min(depth), .groups = "drop")
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
return(min_depth_frame)
}
Expand All @@ -105,10 +26,12 @@ min_depth_distribution.ranger <- function(forest){
min_depth_count <- function(min_depth_frame){
tree <- NULL; minimal_depth <- NULL; variable <- NULL
mean_tree_depth <- dplyr::group_by(min_depth_frame, tree) %>%
dplyr::summarize(depth = max(minimal_depth) + 1) %>% as.data.frame()
dplyr::summarize(depth = max(minimal_depth) + 1) %>%
as.data.frame()
mean_tree_depth <- mean(mean_tree_depth$depth)
min_depth_count <- dplyr::group_by(min_depth_frame, variable, minimal_depth) %>%
dplyr::summarize(count = n()) %>% as.data.frame()
dplyr::summarize(count = n(), .groups = "drop") %>%
as.data.frame()
occurrences <- stats::aggregate(count ~ variable, data = min_depth_count, sum)
colnames(occurrences)[2] <- "no_of_occurrences"
min_depth_count <-
Expand Down
Loading

0 comments on commit c92335e

Please sign in to comment.