Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor moe::predict_smooth #221

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ego/examples/mopta08.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ fn mopta(x: &ArrayView2<f64>, indices: Option<&[usize]>) -> Array2<f64> {
path_exe.push(r"ego/examples");
path_exe.push(mopta_exe);

Command::new(path_exe)
let _ = Command::new(path_exe)
.spawn()
.expect("ls command failed to start");
.expect("ls command failed to start")
.wait();

std::thread::sleep(std::time::Duration::from_secs(1));
let y_i = get_output().unwrap();
Expand Down
3 changes: 0 additions & 3 deletions gp/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1574,9 +1574,6 @@ mod tests {
let _gp = Kriging::params()
.fit(&Dataset::new(xt.clone(), yt.clone()))
.expect("GP fit error");
// println!("theta = {}", gp.theta());
// let xtest = array![[0.1]];
// println!("pred({}) = {}", &xtest, gp.predict(&xtest).unwrap());
}

fn x2sinx(x: &Array2<f64>) -> Array2<f64> {
Expand Down
59 changes: 26 additions & 33 deletions moe/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ impl GpMixtureValidParams<f64> {
let errors = scale_factors.map(move |&factor| {
let gmx2 = gmx.clone();
let gmx2 = gmx2.heaviside_factor(factor);
let pred = predict_smooth(experts, &gmx2, xtest).unwrap();
let pred = predict_smooth(experts, &gmx2, xtest, ytest.ncols()).unwrap();
pred.sub(ytest).mapv(|x| x * x).sum().sqrt() / xtest.mapv(|x| x * x).sum().sqrt()
});

Expand Down Expand Up @@ -387,22 +387,17 @@ fn predict_smooth(
experts: &[Box<dyn FullGpSurrogate>],
gmx: &GaussianMixture<f64>,
points: &ArrayBase<impl Data<Elem = f64>, Ix2>,
ny: usize,
) -> Result<Array2<f64>> {
let probas = gmx.predict_probas(points);
let mut preds = Array1::<f64>::zeros(points.nrows());

Zip::from(&mut preds)
.and(points.rows())
.and(probas.rows())
.for_each(|y, x, p| {
let x = x.insert_axis(Axis(0));
let preds: Array1<f64> = experts
.iter()
.map(|gp| gp.predict(&x).unwrap()[[0, 0]])
.collect();
*y = (preds * p).sum();
});
Ok(preds.insert_axis(Axis(1)))
let preds: Array2<f64> = experts
.iter()
.enumerate()
.map(|(i, gp)| {
gp.predict(&points.view()).unwrap() * probas.column(i).to_owned().insert_axis(Axis(1))
})
.fold(Array2::zeros((points.nrows(), ny)), |acc, pred| acc + pred);
Ok(preds)
}

/// Mixture of gaussian process experts
Expand Down Expand Up @@ -597,7 +592,7 @@ impl GpMixture {
/// or another (ie responsabilities).
/// The smooth recombination of each cluster expert responsabilty is used to get the result.
pub fn predict_smooth(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array2<f64>> {
predict_smooth(&self.experts, &self.gmx, x)
predict_smooth(&self.experts, &self.gmx, x, self.output_dim())
}

/// Predict variances at a set of points `x` specified as (n, nx) matrix.
Expand All @@ -609,21 +604,19 @@ impl GpMixture {
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Result<Array2<f64>> {
let probas = self.gmx.predict_probas(x);
let mut preds = Array1::<f64>::zeros(x.nrows());

Zip::from(&mut preds)
.and(x.rows())
.and(probas.rows())
.for_each(|y, x, p| {
let x = x.insert_axis(Axis(0));
let preds: Array1<f64> = self
.experts
.iter()
.map(|gp| gp.predict_var(&x).unwrap()[[0, 0]])
.collect();
*y = (preds * p * p).sum();
});
Ok(preds.insert_axis(Axis(1)))
let preds: Array2<f64> = self
.experts
.iter()
.enumerate()
.map(|(i, gp)| {
let p = probas.column(i).to_owned().insert_axis(Axis(1));
gp.predict_var(&x.view()).unwrap() * &p * &p
})
.fold(
Array2::zeros((x.nrows(), self.output_dim())),
|acc, pred| acc + pred,
);
Ok(preds)
}

/// Predict derivatives of the output at a set of points `x` specified as (n, nx) matrix.
Expand Down Expand Up @@ -731,7 +724,7 @@ impl GpMixture {
pub fn predict_hard(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array2<f64>> {
let clustering = self.gmx.predict(x);
trace!("Clustering {:?}", clustering);
let mut preds = Array2::zeros((x.nrows(), 1));
let mut preds = Array2::zeros((x.nrows(), self.output_dim()));
Zip::from(preds.rows_mut())
.and(x.rows())
.and(&clustering)
Expand All @@ -756,7 +749,7 @@ impl GpMixture {
) -> Result<Array2<f64>> {
let clustering = self.gmx.predict(x);
trace!("Clustering {:?}", clustering);
let mut variances = Array2::zeros((x.nrows(), 1));
let mut variances = Array2::zeros((x.nrows(), self.output_dim()));
Zip::from(variances.rows_mut())
.and(x.rows())
.and(&clustering)
Expand Down
Loading