From 749ce3e578887b6859016bddb2e10ee1417de10f Mon Sep 17 00:00:00 2001 From: Sky Qiu <36117216+tq21@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:50:50 -0700 Subject: [PATCH 1/3] k-means clustering based knot-point screening --- NAMESPACE | 2 + R/make_basis.R | 134 ++++++++++++++++++++++++++++++++++++++++++++----- R/predict.R | 3 +- 3 files changed, 124 insertions(+), 15 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index f5c33f1f..b28c029d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -24,6 +24,7 @@ importFrom(data.table,`:=`) importFrom(data.table,data.table) importFrom(data.table,rbindlist) importFrom(data.table,setorder) +importFrom(data.table,uniqueN) importFrom(glmnet,cv.glmnet) importFrom(glmnet,glmnet) importFrom(methods,is) @@ -37,6 +38,7 @@ importFrom(origami,validation) importFrom(stats,aggregate) importFrom(stats,as.formula) importFrom(stats,coef) +importFrom(stats,kmeans) importFrom(stats,median) importFrom(stats,plogis) importFrom(stats,predict) diff --git a/R/make_basis.R b/R/make_basis.R index d8cd1de2..2408d8c9 100644 --- a/R/make_basis.R +++ b/R/make_basis.R @@ -28,9 +28,14 @@ #' @return A \code{list} containing the basis functions generated from a set of #' input columns. basis_list_cols <- function(cols, x, smoothness_orders, include_zero_order, - include_lower_order = FALSE) { + include_lower_order = FALSE, ignore_cols = FALSE) { + # first, subset only to columns of interest - x_sub <- x[, cols, drop = FALSE] + if (ignore_cols) { + x_sub <- x + } else { + x_sub <- x[, cols, drop = FALSE] + } # call Rcpp routine to produce the list of basis functions basis_list <- make_basis_list(x_sub, cols, smoothness_orders) @@ -51,12 +56,12 @@ basis_list_cols <- function(cols, x, smoothness_orders, include_zero_order, new_smoothness_orders[col] <- new_smoothness_orders[col] - 1 return(basis_list_cols(cols, x, new_smoothness_orders, - include_zero_order, - include_lower_order = TRUE + include_zero_order, + include_lower_order = TRUE )) }) basis_list <- union(basis_list, unlist(more_basis_list, - recursive = FALSE + recursive = FALSE )) } } @@ -98,7 +103,7 @@ basis_list_cols <- function(cols, x, smoothness_orders, include_zero_order, #' @return A \code{list} containing basis functions and cutoffs generated from #' a set of input columns up to a particular pre-specified degree. basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order, - include_lower_order) { + include_lower_order, x_col_idx = NULL) { # get dimensionality of input matrix p <- ncol(x) @@ -106,12 +111,22 @@ basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order, if (degree > p) stop("The problem is not defined for degree > p.") # compute combinations of columns and generate a list of basis functions - all_cols <- utils::combn(p, degree) + if (is.null(x_col_idx)) { + all_cols <- utils::combn(p, degree) + } else { + # If x_col_idx is provided, that means knot-point screening algorithm is + # used, x could be a subset of original X. + # We need to make sure the column indices of x are correct. + # Note that in this case, ncol(x) must == degree. + all_cols <- matrix(x_col_idx) + } + all_basis_lists <- apply(all_cols, 2, basis_list_cols, - x = x, - smoothness_orders = smoothness_orders, - include_zero_order = include_zero_order, - include_lower_order = include_lower_order + x = x, + smoothness_orders = smoothness_orders, + include_zero_order = include_zero_order, + include_lower_order = include_lower_order, + ignore_cols = !is.null(x_col_idx) ) basis_list <- unlist(all_basis_lists, recursive = FALSE) @@ -147,6 +162,10 @@ basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order, #' @param include_lower_order A \code{logical}, like \code{include_zero_order}, #' except including all basis functions of lower smoothness degrees than #' specified via \code{smoothness_orders}. +#' @param screen_knots A \code{logical}, indicating whether to screen the knot +#' points for each basis function. If \code{TRUE}, the knot points are screened +#' using k-mean clustering. The number of knots for each degree of basis +#' function is determined by the `num_knots` argument. #' @param num_knots A vector of length \code{max_degree}, which determines how #' granular the knot points to generate basis functions should be for each #' degree of basis function. The first entry of \code{num_knots} determines @@ -186,7 +205,8 @@ enumerate_basis <- function(x, smoothness_orders = rep(0, ncol(x)), include_zero_order = FALSE, include_lower_order = FALSE, - num_knots = NULL) { + num_knots = NULL, + screen_knots = FALSE) { if (!is.matrix(x)) { x <- as.matrix(x) } @@ -213,7 +233,21 @@ enumerate_basis <- function(x, } else { n_bin <- num_knots[degree] } - x <- quantizer(x, n_bin) + + if (screen_knots) { + basis_list <- kmeans_knot_screen( + X = x, + bins = n_bin, + degree = degree, + smoothness_orders = smoothness_orders, + include_zero_order = include_zero_order, + include_lower_order = include_lower_order) + + return(basis_list) + + } else { + x <- quantizer(x, n_bin) + } } return(basis_of_degree( x, degree, smoothness_orders, include_zero_order, @@ -338,3 +372,77 @@ quantizer <- function(X, bins) { } return(quantizer(X)) } + +#' @title Screen knot-points using k-means clustering +#' +#' @description Using k-mean clustering at each basis function level to screen +#' knot-points. An effective way to reduce the number of basis function and +#' improve the computational speed while maintaining good predictive +#' performance of HAL fit. +#' +#' @param X A \code{numeric} vector or matrix of input. +#' @param bins A \code{numeric} scalar indicating the number of knot-points for +#' each basis function. +#' @param degree A \code{numeric} scalar indicating the degree of the basis +#' functions. +#' @param smoothness_orders Argument for `enumerate_basis()` +#' @param include_zero_order Argument for `enumerate_basis()` +#' @param include_lower_order Argument for `enumerate_basis()` +#' +#' @importFrom stats kmeans +#' @importFrom data.table uniqueN +#' +#' @keywords internal +kmeans_knot_screen <- function(X, + bins, + degree, + smoothness_orders, + include_zero_order, + include_lower_order) { + if (is.null(bins)) { + return(X) + } + + if (!is.matrix(X)) { + X <- as.matrix(X) + } + + # function to screen knot-points for given x (subset of X) + screen_knots <- function(x) { + if (uniqueN(x) <= bins) { + return(x) + } + if (ncol(x) == 1 & all(x %in% c(0, 1))) { + return(rep(0, length(x))) + } + if (bins == 1) { + return(rep(min(x), length(x))) + } + + # k-means clustering of knot-points + k_means <- kmeans(x = x, centers = bins, algorithm = "MacQueen") + centroids <- k_means$centers + + return(centroids) + } + + # generate column index sets up to max_degree interactions + var_idx_num <- utils::combn(1:ncol(X), degree, simplify = FALSE) + + # screen knot-points for each column index set + basis_list <- unlist(lapply(var_idx_num, function(var_idx) { + X_sub <- X[, var_idx, drop = FALSE] + suppressWarnings(knots <- screen_knots(X_sub)) + + # make basis list for the screened knot-points + basis_of_degree(x = knots, + degree = degree, + smoothness_orders = smoothness_orders, + include_zero_order = include_zero_order, + include_lower_order = include_lower_order, + x_col_idx = var_idx) + + }), recursive = FALSE) + + return(basis_list) +} diff --git a/R/predict.R b/R/predict.R index 3ef9f3ad..b1d5e125 100644 --- a/R/predict.R +++ b/R/predict.R @@ -41,7 +41,6 @@ predict.hal9001 <- function(object, offset = NULL, type = c("response", "link"), ...) { - family <- ifelse(inherits(object$family, "family"), object$family$family, object$family) type <- match.arg(type) @@ -89,7 +88,7 @@ predict.hal9001 <- function(object, } else { preds <- as.vector(Matrix::tcrossprod( x = pred_x_basis, - y = object$coefs[-1] + y = matrix(object$coefs[-1], nrow = 1) ) + object$coefs[1]) } } else { From 2d11694aa0554d10b38911ad3c36e53c7d686ccc Mon Sep 17 00:00:00 2001 From: Sky Qiu <36117216+tq21@users.noreply.github.com> Date: Sun, 28 Jul 2024 11:25:17 -0700 Subject: [PATCH 2/3] minor fix for checking unique number of knots --- R/make_basis.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/make_basis.R b/R/make_basis.R index 2408d8c9..0425fee3 100644 --- a/R/make_basis.R +++ b/R/make_basis.R @@ -399,6 +399,7 @@ kmeans_knot_screen <- function(X, smoothness_orders, include_zero_order, include_lower_order) { + if (is.null(bins)) { return(X) } @@ -409,8 +410,8 @@ kmeans_knot_screen <- function(X, # function to screen knot-points for given x (subset of X) screen_knots <- function(x) { - if (uniqueN(x) <= bins) { - return(x) + if (nrow(unique(x)) <= bins) { + return(unique(x)) } if (ncol(x) == 1 & all(x %in% c(0, 1))) { return(rep(0, length(x))) From a81af69b28639221b4146ba8cfdcfcc83bb68aa2 Mon Sep 17 00:00:00 2001 From: Sky Qiu <36117216+tq21@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:05:46 -0700 Subject: [PATCH 3/3] pam --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/make_basis.R | 105 +++++++++++++++++++++++++++++++++++--- man/basis_list_cols.Rd | 3 +- man/basis_of_degree.Rd | 3 +- man/enumerate_basis.Rd | 13 ++++- man/formula_helpers.Rd | 20 -------- man/generate_all_rules.Rd | 14 +++++ man/kmeans_knot_screen.Rd | 37 ++++++++++++++ man/pam_knot_screen.Rd | 37 ++++++++++++++ 10 files changed, 203 insertions(+), 32 deletions(-) delete mode 100644 man/formula_helpers.Rd create mode 100644 man/generate_all_rules.Rd create mode 100644 man/kmeans_knot_screen.Rd create mode 100644 man/pam_knot_screen.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 21c99e4b..7bf7308a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,5 +68,5 @@ LinkingTo: Rcpp, RcppEigen VignetteBuilder: knitr -RoxygenNote: 7.2.0 +RoxygenNote: 7.2.1 Roxygen: list(markdown = TRUE) diff --git a/NAMESPACE b/NAMESPACE index b28c029d..fab3d4a2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -20,6 +20,7 @@ import(Rcpp) importFrom(Matrix,tcrossprod) importFrom(Rcpp,sourceCpp) importFrom(assertthat,assert_that) +importFrom(cluster,pam) importFrom(data.table,`:=`) importFrom(data.table,data.table) importFrom(data.table,rbindlist) diff --git a/R/make_basis.R b/R/make_basis.R index 0425fee3..deea96e5 100644 --- a/R/make_basis.R +++ b/R/make_basis.R @@ -175,6 +175,9 @@ basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order, #' for a kth degree basis function, which is the tensor product of k #' univariate basis functions, this determines the number of knot points to be #' used for each univariate basis function in the tensor product. +#' @param algorithm A \code{character} string specifying the algorithm to use +#' for knot point screening. The default is \code{"kmeans"}. The other option +#' is \code{"pam"}. #' #' @export #' @@ -206,7 +209,8 @@ enumerate_basis <- function(x, include_zero_order = FALSE, include_lower_order = FALSE, num_knots = NULL, - screen_knots = FALSE) { + screen_knots = FALSE, + algorithm = "kmeans") { if (!is.matrix(x)) { x <- as.matrix(x) } @@ -235,13 +239,23 @@ enumerate_basis <- function(x, } if (screen_knots) { - basis_list <- kmeans_knot_screen( - X = x, - bins = n_bin, - degree = degree, - smoothness_orders = smoothness_orders, - include_zero_order = include_zero_order, - include_lower_order = include_lower_order) + if (algorithm == "kmeans") { + basis_list <- kmeans_knot_screen( + X = x, + bins = n_bin, + degree = degree, + smoothness_orders = smoothness_orders, + include_zero_order = include_zero_order, + include_lower_order = include_lower_order) + } else if (algorithm == "pam") { + basis_list <- pam_knot_screen( + X = x, + bins = n_bin, + degree = degree, + smoothness_orders = smoothness_orders, + include_zero_order = include_zero_order, + include_lower_order = include_lower_order) + } return(basis_list) @@ -447,3 +461,78 @@ kmeans_knot_screen <- function(X, return(basis_list) } + +#' @title Screen knot-points using PAM clustering +#' +#' @description Using PAM clustering at each basis function level to screen +#' knot-points. An effective way to reduce the number of basis function and +#' improve the computational speed while maintaining good predictive +#' performance of HAL fit. +#' +#' @param X A \code{numeric} vector or matrix of input. +#' @param bins A \code{numeric} scalar indicating the number of knot-points for +#' each basis function. +#' @param degree A \code{numeric} scalar indicating the degree of the basis +#' functions. +#' @param smoothness_orders Argument for `enumerate_basis()` +#' @param include_zero_order Argument for `enumerate_basis()` +#' @param include_lower_order Argument for `enumerate_basis()` +#' +#' @importFrom cluster pam +#' @importFrom data.table uniqueN +#' +#' @keywords internal +pam_knot_screen <- function(X, + bins, + degree, + smoothness_orders, + include_zero_order, + include_lower_order) { + + if (is.null(bins)) { + return(X) + } + + if (!is.matrix(X)) { + X <- as.matrix(X) + } + + # function to screen knot-points for given x (subset of X) + screen_knots <- function(x) { + if (nrow(unique(x)) <= bins) { + return(unique(x)) + } + if (ncol(x) == 1 & all(x %in% c(0, 1))) { + return(rep(0, length(x))) + } + if (bins == 1) { + return(rep(min(x), length(x))) + } + + # pam clustering of knot-points + pam_obj <- pam(x = x, k = bins, metric = "euclidean") + centroids <- pam_obj$medoids + + return(centroids) + } + + # generate column index sets up to max_degree interactions + var_idx_num <- utils::combn(1:ncol(X), degree, simplify = FALSE) + + # screen knot-points for each column index set + basis_list <- unlist(lapply(var_idx_num, function(var_idx) { + X_sub <- X[, var_idx, drop = FALSE] + suppressWarnings(knots <- screen_knots(X_sub)) + + # make basis list for the screened knot-points + basis_of_degree(x = knots, + degree = degree, + smoothness_orders = smoothness_orders, + include_zero_order = include_zero_order, + include_lower_order = include_lower_order, + x_col_idx = var_idx) + + }), recursive = FALSE) + + return(basis_list) +} diff --git a/man/basis_list_cols.Rd b/man/basis_list_cols.Rd index 839a30ce..a65db78f 100644 --- a/man/basis_list_cols.Rd +++ b/man/basis_list_cols.Rd @@ -9,7 +9,8 @@ basis_list_cols( x, smoothness_orders, include_zero_order, - include_lower_order = FALSE + include_lower_order = FALSE, + ignore_cols = FALSE ) } \arguments{ diff --git a/man/basis_of_degree.Rd b/man/basis_of_degree.Rd index 2bc576a4..5b5c6a9a 100644 --- a/man/basis_of_degree.Rd +++ b/man/basis_of_degree.Rd @@ -9,7 +9,8 @@ basis_of_degree( degree, smoothness_orders, include_zero_order, - include_lower_order + include_lower_order, + x_col_idx = NULL ) } \arguments{ diff --git a/man/enumerate_basis.Rd b/man/enumerate_basis.Rd index 4ce8176d..4d55bf12 100644 --- a/man/enumerate_basis.Rd +++ b/man/enumerate_basis.Rd @@ -10,7 +10,9 @@ enumerate_basis( smoothness_orders = rep(0, ncol(x)), include_zero_order = FALSE, include_lower_order = FALSE, - num_knots = NULL + num_knots = NULL, + screen_knots = FALSE, + algorithm = "kmeans" ) } \arguments{ @@ -49,6 +51,15 @@ knot points to be used for the kth degree basis functions. Specifically, for a kth degree basis function, which is the tensor product of k univariate basis functions, this determines the number of knot points to be used for each univariate basis function in the tensor product.} + +\item{screen_knots}{A \code{logical}, indicating whether to screen the knot +points for each basis function. If \code{TRUE}, the knot points are screened +using k-mean clustering. The number of knots for each degree of basis +function is determined by the \code{num_knots} argument.} + +\item{algorithm}{A \code{character} string specifying the algorithm to use +for knot point screening. The default is \code{"kmeans"}. The other option +is \code{"pam"}.} } \value{ A \code{list} of basis functions generated for all covariates and diff --git a/man/formula_helpers.Rd b/man/formula_helpers.Rd deleted file mode 100644 index 893a2de1..00000000 --- a/man/formula_helpers.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/formula_hal9001.R -\name{formula_helpers} -\alias{formula_helpers} -\alias{fill_dots_helper} -\alias{fill_dots} -\title{Formula Helpers} -\usage{ -fill_dots_helper(var_names, .) - -fill_dots(var_names, .) -} -\arguments{ -\item{var_names}{A \code{character} vector of variable names.} - -\item{.}{Specification of variables for use in the formula.} -} -\description{ -Formula Helpers -} diff --git a/man/generate_all_rules.Rd b/man/generate_all_rules.Rd new file mode 100644 index 00000000..6fc2967a --- /dev/null +++ b/man/generate_all_rules.Rd @@ -0,0 +1,14 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/summary.R +\name{generate_all_rules} +\alias{generate_all_rules} +\title{Generates rules based on knot points of the fitted HAL basis functions with +non-zero coefficients.} +\usage{ +generate_all_rules(basis_list, coefs, X_colnames) +} +\description{ +Generates rules based on knot points of the fitted HAL basis functions with +non-zero coefficients. +} +\keyword{internal} diff --git a/man/kmeans_knot_screen.Rd b/man/kmeans_knot_screen.Rd new file mode 100644 index 00000000..ee6ef2db --- /dev/null +++ b/man/kmeans_knot_screen.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/make_basis.R +\name{kmeans_knot_screen} +\alias{kmeans_knot_screen} +\title{Screen knot-points using k-means clustering} +\usage{ +kmeans_knot_screen( + X, + bins, + degree, + smoothness_orders, + include_zero_order, + include_lower_order +) +} +\arguments{ +\item{X}{A \code{numeric} vector or matrix of input.} + +\item{bins}{A \code{numeric} scalar indicating the number of knot-points for +each basis function.} + +\item{degree}{A \code{numeric} scalar indicating the degree of the basis +functions.} + +\item{smoothness_orders}{Argument for \code{enumerate_basis()}} + +\item{include_zero_order}{Argument for \code{enumerate_basis()}} + +\item{include_lower_order}{Argument for \code{enumerate_basis()}} +} +\description{ +Using k-mean clustering at each basis function level to screen +knot-points. An effective way to reduce the number of basis function and +improve the computational speed while maintaining good predictive +performance of HAL fit. +} +\keyword{internal} diff --git a/man/pam_knot_screen.Rd b/man/pam_knot_screen.Rd new file mode 100644 index 00000000..81c726d0 --- /dev/null +++ b/man/pam_knot_screen.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/make_basis.R +\name{pam_knot_screen} +\alias{pam_knot_screen} +\title{Screen knot-points using PAM clustering} +\usage{ +pam_knot_screen( + X, + bins, + degree, + smoothness_orders, + include_zero_order, + include_lower_order +) +} +\arguments{ +\item{X}{A \code{numeric} vector or matrix of input.} + +\item{bins}{A \code{numeric} scalar indicating the number of knot-points for +each basis function.} + +\item{degree}{A \code{numeric} scalar indicating the degree of the basis +functions.} + +\item{smoothness_orders}{Argument for \code{enumerate_basis()}} + +\item{include_zero_order}{Argument for \code{enumerate_basis()}} + +\item{include_lower_order}{Argument for \code{enumerate_basis()}} +} +\description{ +Using PAM clustering at each basis function level to screen +knot-points. An effective way to reduce the number of basis function and +improve the computational speed while maintaining good predictive +performance of HAL fit. +} +\keyword{internal}