Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k-means clustering based knot-point screening #113

Open
wants to merge 3 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ LinkingTo:
Rcpp,
RcppEigen
VignetteBuilder: knitr
RoxygenNote: 7.2.0
RoxygenNote: 7.2.1
Roxygen: list(markdown = TRUE)
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import(Rcpp)
importFrom(Matrix,tcrossprod)
importFrom(Rcpp,sourceCpp)
importFrom(assertthat,assert_that)
importFrom(cluster,pam)
importFrom(data.table,`:=`)
importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,setorder)
importFrom(data.table,uniqueN)
importFrom(glmnet,cv.glmnet)
importFrom(glmnet,glmnet)
importFrom(methods,is)
Expand All @@ -37,6 +39,7 @@ importFrom(origami,validation)
importFrom(stats,aggregate)
importFrom(stats,as.formula)
importFrom(stats,coef)
importFrom(stats,kmeans)
importFrom(stats,median)
importFrom(stats,plogis)
importFrom(stats,predict)
Expand Down
224 changes: 211 additions & 13 deletions R/make_basis.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@
#' @return A \code{list} containing the basis functions generated from a set of
#' input columns.
basis_list_cols <- function(cols, x, smoothness_orders, include_zero_order,
include_lower_order = FALSE) {
include_lower_order = FALSE, ignore_cols = FALSE) {

# first, subset only to columns of interest
x_sub <- x[, cols, drop = FALSE]
if (ignore_cols) {
x_sub <- x
} else {
x_sub <- x[, cols, drop = FALSE]
}

# call Rcpp routine to produce the list of basis functions
basis_list <- make_basis_list(x_sub, cols, smoothness_orders)
Expand All @@ -51,12 +56,12 @@ basis_list_cols <- function(cols, x, smoothness_orders, include_zero_order,
new_smoothness_orders[col] <- new_smoothness_orders[col] - 1

return(basis_list_cols(cols, x, new_smoothness_orders,
include_zero_order,
include_lower_order = TRUE
include_zero_order,
include_lower_order = TRUE
))
})
basis_list <- union(basis_list, unlist(more_basis_list,
recursive = FALSE
recursive = FALSE
))
}
}
Expand Down Expand Up @@ -98,20 +103,30 @@ basis_list_cols <- function(cols, x, smoothness_orders, include_zero_order,
#' @return A \code{list} containing basis functions and cutoffs generated from
#' a set of input columns up to a particular pre-specified degree.
basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order,
include_lower_order) {
include_lower_order, x_col_idx = NULL) {
# get dimensionality of input matrix
p <- ncol(x)

# the estimation problem is not defined when the following is violated
if (degree > p) stop("The problem is not defined for degree > p.")

# compute combinations of columns and generate a list of basis functions
all_cols <- utils::combn(p, degree)
if (is.null(x_col_idx)) {
all_cols <- utils::combn(p, degree)
} else {
# If x_col_idx is provided, the knot-point screening algorithm is in use
# and x may hold only a subset of the columns of the original X, so the
# supplied indices are used to keep the column indices of x correct.
# Note that in this case ncol(x) must equal degree.
all_cols <- matrix(x_col_idx)
}

all_basis_lists <- apply(all_cols, 2, basis_list_cols,
x = x,
smoothness_orders = smoothness_orders,
include_zero_order = include_zero_order,
include_lower_order = include_lower_order
x = x,
smoothness_orders = smoothness_orders,
include_zero_order = include_zero_order,
include_lower_order = include_lower_order,
ignore_cols = !is.null(x_col_idx)
)
basis_list <- unlist(all_basis_lists, recursive = FALSE)

Expand Down Expand Up @@ -147,6 +162,10 @@ basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order,
#' @param include_lower_order A \code{logical}, like \code{include_zero_order},
#' except including all basis functions of lower smoothness degrees than
#' specified via \code{smoothness_orders}.
#' @param screen_knots A \code{logical}, indicating whether to screen the knot
#' points for each basis function. If \code{TRUE}, the knot points are screened
#' by clustering, using the method selected via \code{algorithm} (k-means by
#' default). The number of knots for each degree of basis function is
#' determined by the \code{num_knots} argument.
#' @param num_knots A vector of length \code{max_degree}, which determines how
#' granular the knot points to generate basis functions should be for each
#' degree of basis function. The first entry of \code{num_knots} determines
Expand All @@ -156,6 +175,9 @@ basis_of_degree <- function(x, degree, smoothness_orders, include_zero_order,
#' for a kth degree basis function, which is the tensor product of k
#' univariate basis functions, this determines the number of knot points to be
#' used for each univariate basis function in the tensor product.
#' @param algorithm A \code{character} string specifying the algorithm to use
#' for knot point screening. The default is \code{"kmeans"}. The other option
#' is \code{"pam"}.
#'
#' @export
#'
Expand Down Expand Up @@ -186,7 +208,9 @@ enumerate_basis <- function(x,
smoothness_orders = rep(0, ncol(x)),
include_zero_order = FALSE,
include_lower_order = FALSE,
num_knots = NULL) {
num_knots = NULL,
screen_knots = FALSE,
algorithm = "kmeans") {
if (!is.matrix(x)) {
x <- as.matrix(x)
}
Expand All @@ -213,7 +237,31 @@ enumerate_basis <- function(x,
} else {
n_bin <- num_knots[degree]
}
x <- quantizer(x, n_bin)

if (screen_knots) {
if (algorithm == "kmeans") {
basis_list <- kmeans_knot_screen(
X = x,
bins = n_bin,
degree = degree,
smoothness_orders = smoothness_orders,
include_zero_order = include_zero_order,
include_lower_order = include_lower_order)
} else if (algorithm == "pam") {
basis_list <- pam_knot_screen(
X = x,
bins = n_bin,
degree = degree,
smoothness_orders = smoothness_orders,
include_zero_order = include_zero_order,
include_lower_order = include_lower_order)
}

return(basis_list)

} else {
x <- quantizer(x, n_bin)
}
}
return(basis_of_degree(
x, degree, smoothness_orders, include_zero_order,
Expand Down Expand Up @@ -338,3 +386,153 @@ quantizer <- function(X, bins) {
}
return(quantizer(X))
}

#' @title Screen knot-points using k-means clustering
#'
#' @description Uses k-means clustering within each set of \code{degree}
#'  columns to screen knot-points. This reduces the number of basis functions
#'  and improves computational speed while aiming to maintain good predictive
#'  performance of the HAL fit.
#'
#' @param X A \code{numeric} vector or matrix of input.
#' @param bins A \code{numeric} scalar indicating the number of knot-points for
#'  each basis function. If \code{NULL}, no screening is performed and every
#'  row of each column subset is used as a knot-point.
#' @param degree A \code{numeric} scalar indicating the degree of the basis
#'  functions.
#' @param smoothness_orders Argument for `enumerate_basis()`
#' @param include_zero_order Argument for `enumerate_basis()`
#' @param include_lower_order Argument for `enumerate_basis()`
#'
#' @return A \code{list} formed by concatenating the output of
#'  \code{basis_of_degree()} over every combination of \code{degree} columns
#'  of \code{X}.
#'
#' @importFrom stats kmeans
#' @importFrom data.table uniqueN
#'
#' @keywords internal
kmeans_knot_screen <- function(X,
                               bins,
                               degree,
                               smoothness_orders,
                               include_zero_order,
                               include_lower_order) {
  if (!is.matrix(X)) {
    X <- as.matrix(X)
  }

  # Reduce the rows of x (a column subset of X) to at most `bins` knot-points.
  # Every branch returns a matrix with ncol(x) columns, because the result is
  # handed to basis_of_degree(), which calls ncol() on it.
  screen_knots <- function(x) {
    # no bin count given: keep every row as a candidate knot-point
    if (is.null(bins)) {
      return(x)
    }

    # already few enough distinct knot-points: nothing to cluster
    distinct_rows <- unique(x)
    if (nrow(distinct_rows) <= bins) {
      return(distinct_rows)
    }

    # a single binary column only needs the knot at 0
    if (ncol(x) == 1 && all(x %in% c(0, 1))) {
      return(matrix(0, nrow = 1, ncol = 1, dimnames = list(NULL, colnames(x))))
    }

    # a single knot-point: use the column-wise minimum so that each column
    # keeps a knot on its own scale
    if (bins == 1) {
      return(matrix(
        apply(x, 2, min),
        nrow = 1,
        dimnames = list(NULL, colnames(x))
      ))
    }

    # k-means clustering of knot-points; only warnings raised by the
    # clustering call itself are silenced
    suppressWarnings(
      k_means <- stats::kmeans(x = x, centers = bins, algorithm = "MacQueen")
    )
    return(k_means$centers)
  }

  # generate every set of `degree` column indices
  var_idx_num <- utils::combn(seq_len(ncol(X)), degree, simplify = FALSE)

  # screen knot-points for each column index set, then build its basis list
  basis_list <- unlist(lapply(var_idx_num, function(var_idx) {
    knots <- screen_knots(X[, var_idx, drop = FALSE])

    # x_col_idx carries the original column indices of the screened subset
    basis_of_degree(
      x = knots,
      degree = degree,
      smoothness_orders = smoothness_orders,
      include_zero_order = include_zero_order,
      include_lower_order = include_lower_order,
      x_col_idx = var_idx
    )
  }), recursive = FALSE)

  return(basis_list)
}

#' @title Screen knot-points using PAM clustering
#'
#' @description Uses PAM (partitioning around medoids) clustering within each
#'  set of \code{degree} columns to screen knot-points. This reduces the number
#'  of basis functions and improves computational speed while aiming to
#'  maintain good predictive performance of the HAL fit.
#'
#' @param X A \code{numeric} vector or matrix of input.
#' @param bins A \code{numeric} scalar indicating the number of knot-points for
#'  each basis function. If \code{NULL}, no screening is performed and every
#'  row of each column subset is used as a knot-point.
#' @param degree A \code{numeric} scalar indicating the degree of the basis
#'  functions.
#' @param smoothness_orders Argument for `enumerate_basis()`
#' @param include_zero_order Argument for `enumerate_basis()`
#' @param include_lower_order Argument for `enumerate_basis()`
#'
#' @return A \code{list} formed by concatenating the output of
#'  \code{basis_of_degree()} over every combination of \code{degree} columns
#'  of \code{X}.
#'
#' @importFrom cluster pam
#' @importFrom data.table uniqueN
#'
#' @keywords internal
pam_knot_screen <- function(X,
                            bins,
                            degree,
                            smoothness_orders,
                            include_zero_order,
                            include_lower_order) {
  if (!is.matrix(X)) {
    X <- as.matrix(X)
  }

  # Reduce the rows of x (a column subset of X) to at most `bins` knot-points.
  # Every branch returns a matrix with ncol(x) columns, because the result is
  # handed to basis_of_degree(), which calls ncol() on it.
  screen_knots <- function(x) {
    # no bin count given: keep every row as a candidate knot-point
    if (is.null(bins)) {
      return(x)
    }

    # already few enough distinct knot-points: nothing to cluster
    distinct_rows <- unique(x)
    if (nrow(distinct_rows) <= bins) {
      return(distinct_rows)
    }

    # a single binary column only needs the knot at 0
    if (ncol(x) == 1 && all(x %in% c(0, 1))) {
      return(matrix(0, nrow = 1, ncol = 1, dimnames = list(NULL, colnames(x))))
    }

    # a single knot-point: use the column-wise minimum so that each column
    # keeps a knot on its own scale
    if (bins == 1) {
      return(matrix(
        apply(x, 2, min),
        nrow = 1,
        dimnames = list(NULL, colnames(x))
      ))
    }

    # PAM clustering of knot-points; only warnings raised by the clustering
    # call itself are silenced
    suppressWarnings(
      pam_obj <- cluster::pam(x = x, k = bins, metric = "euclidean")
    )
    return(pam_obj$medoids)
  }

  # generate every set of `degree` column indices
  var_idx_num <- utils::combn(seq_len(ncol(X)), degree, simplify = FALSE)

  # screen knot-points for each column index set, then build its basis list
  basis_list <- unlist(lapply(var_idx_num, function(var_idx) {
    knots <- screen_knots(X[, var_idx, drop = FALSE])

    # x_col_idx carries the original column indices of the screened subset
    basis_of_degree(
      x = knots,
      degree = degree,
      smoothness_orders = smoothness_orders,
      include_zero_order = include_zero_order,
      include_lower_order = include_lower_order,
      x_col_idx = var_idx
    )
  }), recursive = FALSE)

  return(basis_list)
}
3 changes: 1 addition & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ predict.hal9001 <- function(object,
offset = NULL,
type = c("response", "link"),
...) {

family <- ifelse(inherits(object$family, "family"), object$family$family, object$family)

type <- match.arg(type)
Expand Down Expand Up @@ -89,7 +88,7 @@ predict.hal9001 <- function(object,
} else {
preds <- as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = object$coefs[-1]
y = matrix(object$coefs[-1], nrow = 1)
) + object$coefs[1])
}
} else {
Expand Down
3 changes: 2 additions & 1 deletion man/basis_list_cols.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/basis_of_degree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion man/enumerate_basis.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading