Skip to content

Commit

Permalink
Fix Egor solver best iter computation (#89)
Browse files Browse the repository at this point in the history
* Fix best iter computation

* Bump ego, egobox 0.8.2

* Fix test

* Adjust test for reproducibility
  • Loading branch information
relf authored Mar 29, 2023
1 parent 52480f6 commit f4da390
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 13 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "egobox"
version = "0.8.1"
version = "0.8.2"
authors = ["Rémi Lafage <[email protected]>"]
edition = "2018"
description = "A toolbox for efficient global optimization"
Expand Down Expand Up @@ -28,7 +28,7 @@ blas = ["ndarray/blas", "egobox-gp/blas", "egobox-moe/blas", "egobox-ego/blas"]
egobox-doe = { version = "0.8.1", path="./doe" }
egobox-gp = { version = "0.8.1", path="./gp" }
egobox-moe = { version = "0.8.1", path="./moe", features=["persistent"] }
egobox-ego = { version = "0.8.1", path="./ego", features=["persistent"] }
egobox-ego = { version = "0.8.2", path="./ego", features=["persistent"] }

linfa = { version = "0.6.1", default-features = false }

Expand Down
2 changes: 1 addition & 1 deletion ego/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "egobox-ego"
version = "0.8.1"
version = "0.8.2"
authors = ["Rémi Lafage <[email protected]>"]
edition = "2018"
description = "A library for efficient global optimization"
Expand Down
1 change: 1 addition & 0 deletions ego/src/egor_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ where
.data((x_data, y_data))
.clusterings(clusterings)
.sampling(sampling);
initial_state.doe_size = doe.nrows();
initial_state.max_iters = self.n_iter as u64;
initial_state.added = doe.nrows();
initial_state.no_point_added_retries = no_point_added_retries;
Expand Down
23 changes: 15 additions & 8 deletions ego/src/egor_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ pub struct EgorState<F: Float> {
/// Optimization status
pub termination_status: TerminationStatus,

/// Initial doe size
pub(crate) doe_size: usize,
/// Number of added points
pub(crate) added: usize,
/// Previous number of added points
Expand Down Expand Up @@ -401,6 +403,7 @@ where
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,

doe_size: 0,
added: 0,
prev_added: 0,
no_point_added_retries: MAX_POINT_ADDITION_RETRY,
Expand All @@ -426,6 +429,7 @@ where
///
/// // Simulating a new, better parameter vector
/// let mut state = state.data((array![[1.0f64], [2.0f64]], array![[10.0],[5.0]]));
/// state.iter = 2;
/// state.param = Some(array![2.0f64]);
/// state.cost = Some(array![5.0]);
///
Expand All @@ -438,7 +442,7 @@ where
/// assert!(state.is_best());
/// ```
fn update(&mut self) {
// TODO: better implementation should track only track
// TODO: better implementation should only track
// current and best index in data and compare just them
// without finding best in data each time
let data = self.data.as_ref();
Expand All @@ -449,15 +453,18 @@ where
}
Some((x_data, y_data)) => {
let best_index = find_best_result_index(y_data, self.cstr_tol);
let best_iter = best_index.saturating_sub(self.doe_size) as u64 + 1;

let param = x_data.row(best_index).to_owned();
std::mem::swap(&mut self.prev_best_param, &mut self.best_param);
self.best_param = Some(param);
if best_iter > self.last_best_iter {
let param = x_data.row(best_index).to_owned();
std::mem::swap(&mut self.prev_best_param, &mut self.best_param);
self.best_param = Some(param);

let cost = y_data.row(best_index).to_owned();
std::mem::swap(&mut self.prev_best_cost, &mut self.best_cost);
self.best_cost = Some(cost);
self.last_best_iter = self.iter;
let cost = y_data.row(best_index).to_owned();
std::mem::swap(&mut self.prev_best_cost, &mut self.best_cost);
self.best_cost = Some(cost);
self.last_best_iter = best_iter;
}
}
};
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ python-source="python"

[tool.poetry]
name = "egobox"
version = "0.8.1"
version = "0.8.2"
description = "Python binding for egobox EGO optimizer written in Rust"
authors = ["Rémi Lafage <[email protected]>"]

Expand Down
2 changes: 1 addition & 1 deletion python/egobox/tests/test_mixintegor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestMixintEgx(unittest.TestCase):
def test_xsinx(self):
xtypes = [egx.XSpec(egx.XType(egx.XType.INT), [0.0, 25.0])]

egor = egx.Egor(xsinx, xtypes, seed=42, n_doe=7)
egor = egx.Egor(xsinx, xtypes, seed=42, n_doe=3)
res = egor.minimize(n_iter=10)
print(f"Optimization f={res.y_opt} at {res.x_opt}")
self.assertAlmostEqual(-15.125, res.y_opt[0], delta=5e-3)
Expand Down

0 comments on commit f4da390

Please sign in to comment.