From 113a85d0c0877f8ae2c4a03386ac891d3d3291cb Mon Sep 17 00:00:00 2001 From: relf Date: Thu, 12 Dec 2024 17:45:18 +0100 Subject: [PATCH 1/2] Refactor predict_smooth in moe --- moe/src/algorithm.rs | 59 +++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/moe/src/algorithm.rs b/moe/src/algorithm.rs index 274279d0..226ddaf0 100644 --- a/moe/src/algorithm.rs +++ b/moe/src/algorithm.rs @@ -346,7 +346,7 @@ impl GpMixtureValidParams { let errors = scale_factors.map(move |&factor| { let gmx2 = gmx.clone(); let gmx2 = gmx2.heaviside_factor(factor); - let pred = predict_smooth(experts, &gmx2, xtest).unwrap(); + let pred = predict_smooth(experts, &gmx2, xtest, ytest.ncols()).unwrap(); pred.sub(ytest).mapv(|x| x * x).sum().sqrt() / xtest.mapv(|x| x * x).sum().sqrt() }); @@ -387,22 +387,17 @@ fn predict_smooth( experts: &[Box], gmx: &GaussianMixture, points: &ArrayBase, Ix2>, + ny: usize, ) -> Result> { let probas = gmx.predict_probas(points); - let mut preds = Array1::::zeros(points.nrows()); - - Zip::from(&mut preds) - .and(points.rows()) - .and(probas.rows()) - .for_each(|y, x, p| { - let x = x.insert_axis(Axis(0)); - let preds: Array1 = experts - .iter() - .map(|gp| gp.predict(&x).unwrap()[[0, 0]]) - .collect(); - *y = (preds * p).sum(); - }); - Ok(preds.insert_axis(Axis(1))) + let preds: Array2 = experts + .iter() + .enumerate() + .map(|(i, gp)| { + gp.predict(&points.view()).unwrap() * probas.column(i).to_owned().insert_axis(Axis(1)) + }) + .fold(Array2::zeros((points.nrows(), ny)), |acc, pred| acc + pred); + Ok(preds) } /// Mixture of gaussian process experts @@ -597,7 +592,7 @@ impl GpMixture { /// or another (ie responsabilities). /// The smooth recombination of each cluster expert responsabilty is used to get the result. pub fn predict_smooth(&self, x: &ArrayBase, Ix2>) -> Result> { - predict_smooth(&self.experts, &self.gmx, x) + predict_smooth(&self.experts, &self.gmx, x, self.output_dim()) } /// Predict variances at a set of points `x` specified as (n, nx) matrix. @@ -609,21 +604,19 @@ impl GpMixture { x: &ArrayBase, Ix2>, ) -> Result> { let probas = self.gmx.predict_probas(x); - let mut preds = Array1::::zeros(x.nrows()); - - Zip::from(&mut preds) - .and(x.rows()) - .and(probas.rows()) - .for_each(|y, x, p| { - let x = x.insert_axis(Axis(0)); - let preds: Array1 = self - .experts - .iter() - .map(|gp| gp.predict_var(&x).unwrap()[[0, 0]]) - .collect(); - *y = (preds * p * p).sum(); - }); - Ok(preds.insert_axis(Axis(1))) + let preds: Array2 = self + .experts + .iter() + .enumerate() + .map(|(i, gp)| { + let p = probas.column(i).to_owned().insert_axis(Axis(1)); + gp.predict_var(&x.view()).unwrap() * &p * &p + }) + .fold( + Array2::zeros((x.nrows(), self.output_dim())), + |acc, pred| acc + pred, + ); + Ok(preds) } /// Predict derivatives of the output at a set of points `x` specified as (n, nx) matrix. @@ -731,7 +724,7 @@ impl GpMixture { pub fn predict_hard(&self, x: &ArrayBase, Ix2>) -> Result> { let clustering = self.gmx.predict(x); trace!("Clustering {:?}", clustering); - let mut preds = Array2::zeros((x.nrows(), 1)); + let mut preds = Array2::zeros((x.nrows(), self.output_dim())); Zip::from(preds.rows_mut()) .and(x.rows()) .and(&clustering) @@ -756,7 +749,7 @@ impl GpMixture { ) -> Result> { let clustering = self.gmx.predict(x); trace!("Clustering {:?}", clustering); - let mut variances = Array2::zeros((x.nrows(), 1)); + let mut variances = Array2::zeros((x.nrows(), self.output_dim())); Zip::from(variances.rows_mut()) .and(x.rows()) .and(&clustering) From 783f561a65a3b48f65cf058dd5978a8b735a5e31 Mon Sep 17 00:00:00 2001 From: relf Date: Thu, 12 Dec 2024 17:45:33 +0100 Subject: [PATCH 2/2] Cleanup --- ego/examples/mopta08.rs | 5 +++-- gp/src/algorithm.rs | 3 --- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/ego/examples/mopta08.rs b/ego/examples/mopta08.rs index 6dbbc27a..0346c36f 100644 --- a/ego/examples/mopta08.rs +++ b/ego/examples/mopta08.rs @@ -191,9 +191,10 @@ fn mopta(x: &ArrayView2, indices: Option<&[usize]>) -> Array2 { path_exe.push(r"ego/examples"); path_exe.push(mopta_exe); - Command::new(path_exe) + let _ = Command::new(path_exe) .spawn() - .expect("ls command failed to start"); + .expect("ls command failed to start") + .wait(); std::thread::sleep(std::time::Duration::from_secs(1)); let y_i = get_output().unwrap(); diff --git a/gp/src/algorithm.rs b/gp/src/algorithm.rs index 68d71475..c93dfd19 100644 --- a/gp/src/algorithm.rs +++ b/gp/src/algorithm.rs @@ -1574,9 +1574,6 @@ mod tests { let _gp = Kriging::params() .fit(&Dataset::new(xt.clone(), yt.clone())) .expect("GP fit error"); - // println!("theta = {}", gp.theta()); - // let xtest = array![[0.1]]; - // println!("pred({}) = {}", &xtest, gp.predict(&xtest).unwrap()); } fn x2sinx(x: &Array2) -> Array2 {