Skip to content

Commit

Permalink
Panic if training output is not one-dimensional (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
relf authored Nov 21, 2024
1 parent 52e7253 commit 323210b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 1 deletion.
19 changes: 19 additions & 0 deletions gp/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,12 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>, D: Data<Elem
) -> Result<Self::Object> {
let x = dataset.records();
let y = dataset.targets();
if y.ncols() > 1 {
panic!(
"Multiple outputs not handled, a one-dimensional column vector \
as training output data is expected"
);
}
if let Some(d) = self.kpls_dim() {
if *d > x.ncols() {
return Err(GpError::InvalidValueError(format!(
Expand Down Expand Up @@ -1560,6 +1566,19 @@ mod tests {
assert_abs_diff_eq!(*gp.theta().to_vec(), expected);
}

#[test]
#[should_panic]
fn test_multiple_outputs() {
    // A GP training output must be a single column; a two-column `yt`
    // is expected to trigger the multiple-outputs panic in `fit`.
    let xt = array![[0.0], [1.0], [2.0], [3.0], [4.0]];
    let yt = array![[0.0, 10.0], [1.0, -3.], [1.5, 1.5], [0.9, 1.0], [1.0, 0.0]];
    // `xt`/`yt` are not used after this call, so they are moved into the
    // Dataset directly (the previous `.clone()`s were redundant).
    // `fit` panics before returning, so the result is never inspected.
    let _gp = Kriging::params()
        .fit(&Dataset::new(xt, yt))
        .expect("GP fit error");
}

/// Element-wise test objective: f(x) = x^2 * sin(x).
fn x2sinx(x: &Array2<f64>) -> Array2<f64> {
    let squared = x * x;
    squared * x.mapv(f64::sin)
}
Expand Down
25 changes: 24 additions & 1 deletion gp/src/sparse_algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,12 @@ impl<F: Float, Corr: CorrelationModel<F>, D: Data<Elem = F> + Sync>
) -> Result<Self::Object> {
let x = dataset.records();
let y = dataset.targets();
if y.ncols() > 1 {
panic!(
"Multiple outputs not handled, a one-dimensional column vector \
as training output data is expected"
);
}
if let Some(d) = self.kpls_dim() {
if *d > x.ncols() {
return Err(GpError::InvalidValueError(format!(
Expand Down Expand Up @@ -838,7 +844,7 @@ mod tests {
use super::*;

use approx::assert_abs_diff_eq;
use ndarray::Array;
use ndarray::{concatenate, Array};
// use ndarray_npy::{read_npy, write_npy};
use ndarray_npy::write_npy;
use ndarray_rand::rand::SeedableRng;
Expand Down Expand Up @@ -1021,4 +1027,21 @@ mod tests {

save_data(&xt, &yt, &z, &xplot, &sgp_vals, &sgp_vars);
}

#[test]
#[should_panic]
fn test_multiple_outputs() {
    // Seeded RNG so the generated data is reproducible.
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    // Generate training data
    let nt = 200;
    // Variance of the gaussian noise on our training data
    let eta2: f64 = 0.01;
    let (xt, yt) = make_test_data(nt, eta2, &mut rng);
    // Duplicate the single output column: an (nt, 2) training output
    // must trigger the multiple-outputs panic in `fit`.
    let yt = concatenate(Axis(1), &[yt.view(), yt.view()]).unwrap();
    let n_inducings = 30;

    // `xt`/`yt` are moved into the Dataset directly — they are not used
    // afterwards, so the previous `.clone()`s were redundant.
    let _sgp = SparseKriging::params(Inducings::Randomized(n_inducings))
        .fit(&Dataset::new(xt, yt))
        .expect("GP fitted");
}
}
8 changes: 8 additions & 0 deletions python/egobox/tests/test_gpmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ def test_kpls_griewank(self):
error = np.linalg.norm(y_pred - y_test) / np.linalg.norm(y_test)
print(" RMS error: " + str(error))

def test_multi_outputs_exception(self):
    """Fitting a GP with a two-column training output must raise."""
    # Column vector of 5 training inputs (same values as before,
    # written directly instead of transposing a row vector).
    self.xt = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])
    # Two output columns per sample — the invalid case under test.
    self.yt = np.array(
        [[0.0, 10.0], [1.0, -3.0], [1.5, 1.5], [0.9, 1.0], [1.0, 0.0]]
    )
    # NOTE(review): BaseException (not Exception) is asserted — presumably
    # the Rust-side panic surfaces as pyo3's PanicException, which does not
    # subclass Exception; confirm against the binding layer.
    with self.assertRaises(BaseException):
        egx.Gpx.builder().fit(self.xt, self.yt)


if __name__ == "__main__":
unittest.main()
19 changes: 19 additions & 0 deletions python/egobox/tests/test_sgpmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ def test_sgp_random(self):
print(elapsed)
print(sgp)

def test_sgp_multi_outputs_exception(self):
    # Fitting a sparse GP with a multi-column training output must raise.
    # random generator for reproducibility
    rng = np.random.RandomState(0)

    # Generate training data
    nt = 200
    # Variance of the gaussian noise on our training data
    eta2 = [0.01]
    gaussian_noise = rng.normal(loc=0.0, scale=np.sqrt(eta2), size=(nt, 1))
    xt = 2 * rng.rand(nt, 1) - 1
    yt = f_obj(xt) + gaussian_noise
    # Duplicate the output column to build an (nt, 2) training output —
    # the invalid shape under test.
    yt = np.hstack((yt, yt))

    # Pick inducing points randomly in training data
    n_inducing = 30

    # NOTE(review): BaseException (not Exception) is asserted — presumably
    # the Rust-side panic surfaces as pyo3's PanicException, which does not
    # subclass Exception; confirm against the binding layer.
    with self.assertRaises(BaseException):
        egx.SparseGpMix(nz=n_inducing, seed=0).fit(xt, yt)


if __name__ == "__main__":
unittest.main()

0 comments on commit 323210b

Please sign in to comment.