Skip to content

Commit

Permalink
Merge pull request #227 from JamesHWade/tunable-target-weights-and-in…
Browse files Browse the repository at this point in the history
…itial

Tunable target weights and initial Fixes #223 and #222
  • Loading branch information
EmilHvitfeldt authored Aug 15, 2024
2 parents c23e089 + 1e305c6 commit 142836b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# embed (development version)

* `step_umap()` has tunable `initial` and `target_weight` arguments. [#223](https://github.com/tidymodels/embed/issues/223), [#222](https://github.com/tidymodels/embed/issues/222))

# embed 1.1.4

## Improvements
Expand Down
6 changes: 4 additions & 2 deletions R/umap.R
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,15 @@ required_pkgs.step_umap <- function(x, ...) {
#' @rdname tunable_embed
tunable.step_umap <- function(x, ...) {
tibble::tibble(
name = c("num_comp", "neighbors", "min_dist", "learn_rate", "epochs"),
name = c("num_comp", "neighbors", "min_dist", "learn_rate", "epochs", "initial", "target_weight"),
call_info = list(
list(pkg = "dials", fun = "num_comp", range = c(1, 10)),
list(pkg = "dials", fun = "neighbors", range = c(5, 200)),
list(pkg = "dials", fun = "min_dist", range = c(-4, -0.69897)),
list(pkg = "dials", fun = "learn_rate"),
list(pkg = "dials", fun = "epochs", range = c(100, 700))
list(pkg = "dials", fun = "epochs", range = c(100, 700)),
list(pkg = "dials", fun = "initial_umap", values = c("spectral", "normlaplacian", "random", "lvrandom", "laplacian", "pca", "spca", "agspectral")),
list(pkg = "dials", fun = "target_weight", range = c(0, 1))
),
source = "recipe",
component = "step_umap",
Expand Down
4 changes: 3 additions & 1 deletion man/step_umap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions tests/testthat/test-umap.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ test_that("tunable", {
rec_param <- tunable.step_umap(rec$steps[[1]])
expect_equal(
rec_param$name,
c("num_comp", "neighbors", "min_dist", "learn_rate", "epochs")
c("num_comp", "neighbors", "min_dist", "learn_rate", "epochs", "initial", "target_weight")
)
expect_true(all(rec_param$source == "recipe"))
expect_true(is.list(rec_param$call_info))
expect_equal(nrow(rec_param), 5)
expect_equal(nrow(rec_param), 7L)
expect_equal(
names(rec_param),
c("name", "call_info", "source", "component", "component_id")
Expand Down Expand Up @@ -369,11 +369,13 @@ test_that("tunable is setup to works with extract_parameter_set_dials", {
neighbors = hardhat::tune(),
min_dist = hardhat::tune(),
learn_rate = hardhat::tune(),
epochs = hardhat::tune()
epochs = hardhat::tune(),
initial = hardhat::tune(),
target_weight = hardhat::tune()
)

params <- extract_parameter_set_dials(rec)

expect_s3_class(params, "parameters")
expect_identical(nrow(params), 5L)
expect_identical(nrow(params), 7L)
})

0 comments on commit 142836b

Please sign in to comment.