diff --git a/gp/src/algorithm.rs b/gp/src/algorithm.rs index 8a70c618..e0b8a3f0 100644 --- a/gp/src/algorithm.rs +++ b/gp/src/algorithm.rs @@ -734,6 +734,12 @@ impl, Corr: CorrelationModel, D: Data Result { let x = dataset.records(); let y = dataset.targets(); + if y.ncols() > 1 { + panic!( + "Multiple outputs not handled, a one-dimensional column vector \ + as training output data is expected" + ); + } if let Some(d) = self.kpls_dim() { if *d > x.ncols() { return Err(GpError::InvalidValueError(format!( @@ -1560,6 +1566,19 @@ mod tests { assert_abs_diff_eq!(*gp.theta().to_vec(), expected); } + #[test] + #[should_panic] + fn test_multiple_outputs() { + let xt = array![[0.0], [1.0], [2.0], [3.0], [4.0]]; + let yt = array![[0.0, 10.0], [1.0, -3.], [1.5, 1.5], [0.9, 1.0], [1.0, 0.0]]; + let _gp = Kriging::params() + .fit(&Dataset::new(xt.clone(), yt.clone())) + .expect("GP fit error"); + // println!("theta = {}", gp.theta()); + // let xtest = array![[0.1]]; + // println!("pred({}) = {}", &xtest, gp.predict(&xtest).unwrap()); + } + fn x2sinx(x: &Array2) -> Array2 { (x * x) * (x).mapv(|v| v.sin()) } diff --git a/gp/src/sparse_algorithm.rs b/gp/src/sparse_algorithm.rs index 0877fcb9..9f41fc68 100644 --- a/gp/src/sparse_algorithm.rs +++ b/gp/src/sparse_algorithm.rs @@ -418,6 +418,12 @@ impl, D: Data + Sync> ) -> Result { let x = dataset.records(); let y = dataset.targets(); + if y.ncols() > 1 { + panic!( + "Multiple outputs not handled, a one-dimensional column vector \ + as training output data is expected" + ); + } if let Some(d) = self.kpls_dim() { if *d > x.ncols() { return Err(GpError::InvalidValueError(format!( @@ -838,7 +844,7 @@ mod tests { use super::*; use approx::assert_abs_diff_eq; - use ndarray::Array; + use ndarray::{concatenate, Array}; // use ndarray_npy::{read_npy, write_npy}; use ndarray_npy::write_npy; use ndarray_rand::rand::SeedableRng; @@ -1021,4 +1027,21 @@ mod tests { save_data(&xt, &yt, &z, &xplot, &sgp_vals, &sgp_vars); } + + #[test] + 
#[should_panic] + fn test_multiple_outputs() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + // Generate training data + let nt = 200; + // Variance of the gaussian noise on our training data + let eta2: f64 = 0.01; + let (xt, yt) = make_test_data(nt, eta2, &mut rng); + let yt = concatenate(Axis(1), &[yt.view(), yt.view()]).unwrap(); + let n_inducings = 30; + + let _sgp = SparseKriging::params(Inducings::Randomized(n_inducings)) + .fit(&Dataset::new(xt.clone(), yt.clone())) + .expect("GP fitted"); + } } diff --git a/python/egobox/tests/test_gpmix.py b/python/egobox/tests/test_gpmix.py index 3ae3ea7a..76cc7dbb 100644 --- a/python/egobox/tests/test_gpmix.py +++ b/python/egobox/tests/test_gpmix.py @@ -119,6 +119,14 @@ def test_kpls_griewank(self): error = np.linalg.norm(y_pred - y_test) / np.linalg.norm(y_test) print(" RMS error: " + str(error)) + def test_multi_outputs_exception(self): + self.xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T + self.yt = np.array( + [[0.0, 10.0], [1.0, -3.0], [1.5, 1.5], [0.9, 1.0], [1.0, 0.0]] + ) + with self.assertRaises(BaseException): + egx.Gpx.builder().fit(self.xt, self.yt) + if __name__ == "__main__": unittest.main() diff --git a/python/egobox/tests/test_sgpmix.py b/python/egobox/tests/test_sgpmix.py index 4b202111..a838f298 100644 --- a/python/egobox/tests/test_sgpmix.py +++ b/python/egobox/tests/test_sgpmix.py @@ -60,6 +60,25 @@ def test_sgp_random(self): print(elapsed) print(sgp) + def test_sgp_multi_outputs_exception(self): + # random generator for reproducibility + rng = np.random.RandomState(0) + + # Generate training data + nt = 200 + # Variance of the gaussian noise on our training data + eta2 = [0.01] + gaussian_noise = rng.normal(loc=0.0, scale=np.sqrt(eta2), size=(nt, 1)) + xt = 2 * rng.rand(nt, 1) - 1 + yt = f_obj(xt) + gaussian_noise + yt = np.hstack((yt, yt)) + + # Pick inducing points randomly in training data + n_inducing = 30 + + with self.assertRaises(BaseException): + egx.SparseGpMix(nz=n_inducing,
seed=0).fit(xt, yt) + if __name__ == "__main__": unittest.main()