Skip to content

Commit

Permalink
Merge pull request #234 from tidymodels/type-checkers
Browse files Browse the repository at this point in the history
Add rlang type checkers
  • Loading branch information
EmilHvitfeldt authored Nov 8, 2024
2 parents c3b2b5d + e361ae8 commit 9c10054
Show file tree
Hide file tree
Showing 39 changed files with 1,555 additions and 81 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Imports:
generics (>= 0.1.0),
lifecycle,
purrr,
rlang (>= 0.4.10),
rlang (>= 1.1.0),
rsample,
stats,
tibble,
Expand Down
3 changes: 3 additions & 0 deletions R/collapse_cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ step_collapse_cart <-
id = rand_id("step_collapse_cart")) {
recipes_pkg_check(required_pkgs.step_discretize_cart())

check_number_decimal(cost_complexity, min = 0)
check_number_whole(min_n, min = 1)

add_step(
recipe,
step_collapse_cart_new(
Expand Down
5 changes: 2 additions & 3 deletions R/collapse_stringdist.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ step_collapse_stringdist <-
columns = NULL,
skip = FALSE,
id = rand_id("collapse_stringdist")) {
if (is.null(distance)) {
cli::cli_abort("The {.arg distance} argument must be set.")
}
check_number_decimal(distance, min = 0)
check_string(method)

add_step(
recipe,
Expand Down
6 changes: 5 additions & 1 deletion R/discretize_cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ cart_binning <- function(predictor, term, outcome, cost_complexity, tree_depth,
prep.step_discretize_cart <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_decimal(x$cost_complexity, min = 0, arg = "cost_complexity")
check_number_decimal(x$tree_depth, min = 0, arg = "tree_depth")
check_number_decimal(x$min_n, min = 0, arg = "min_n")

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts)
if (isFALSE(were_weights_used)) {
Expand Down Expand Up @@ -261,7 +265,7 @@ bake.step_discretize_cart <- function(object, new_data, ...) {
dig.lab = 4
)

check_name(binned_data, new_data, object)
recipes::check_name(binned_data, new_data, object)
new_data <- binned_data
}
}
Expand Down
8 changes: 7 additions & 1 deletion R/discretize_xgb.R
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ xgb_binning <- function(df, outcome, predictor, sample_val, learn_rate,
prep.step_discretize_xgb <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_decimal(x$sample_val, min = 0, max = 1, arg = "sample_val")
check_number_decimal(x$learn_rate, min = 0, arg = "learn_rate")
check_number_whole(x$num_breaks, min = 0, arg = "num_breaks")
check_number_whole(x$tree_depth, min = 0, arg = "tree_depth")
check_number_whole(x$min_n, min = 0, arg = "min_n")

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts)
if (isFALSE(were_weights_used) || is.null(wts)) {
Expand Down Expand Up @@ -496,7 +502,7 @@ bake.step_discretize_xgb <- function(object, new_data, ...) {
dig.lab = 4
)

check_name(binned_data, new_data, object)
recipes::check_name(binned_data, new_data, object)
new_data <- binned_data
}
}
Expand Down
5 changes: 4 additions & 1 deletion R/embed.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ step_embed_new <-
prep.step_embed <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_terms, min = 0, arg = "num_terms")
check_number_whole(x$hidden_units, min = 0, arg = "hidden_units")

if (length(col_names) > 0) {
check_type(training[, col_names], types = c("string", "factor", "ordered"))
y_name <- recipes_eval_select(x$outcome, training, info)
Expand Down Expand Up @@ -429,7 +432,7 @@ bake.step_embed <- function(object, new_data, ...) {
prefix = col_name
)

tmp <- check_name(tmp, new_data, object, names(tmp))
tmp <- recipes::check_name(tmp, new_data, object, names(tmp))

new_data <- vec_cbind(new_data, tmp)
}
Expand Down
4 changes: 3 additions & 1 deletion R/feature_hash.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ step_feature_hash_new <-
prep.step_feature_hash <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_hash, min = 0, arg = "num_hash")

if (length(col_names) > 0) {
check_type(training[, col_names], types = c("string", "factor", "ordered"))
}
Expand Down Expand Up @@ -222,7 +224,7 @@ bake.step_feature_hash <- function(object, new_data, ...) {
object$num_hash
)

new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_cols <- recipes::check_name(new_cols, new_data, object, names(new_cols))

new_data <- vec_cbind(new_data, new_cols)

Expand Down
Loading

0 comments on commit 9c10054

Please sign in to comment.