From f14d5894f3c9a3ae1870ce834b531ad1a8178287 Mon Sep 17 00:00:00 2001 From: Steven Dahdah Date: Thu, 26 Sep 2024 14:54:12 -0400 Subject: [PATCH] Fix compatibility with multioutput=raw_values --- pykoop/koopman_pipeline.py | 2 +- tests/test_koopman_pipeline.py | 64 ++++++++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/pykoop/koopman_pipeline.py b/pykoop/koopman_pipeline.py index 3dafbb4..821c79d 100644 --- a/pykoop/koopman_pipeline.py +++ b/pykoop/koopman_pipeline.py @@ -3425,7 +3425,7 @@ def score_trajectory( else: score = regression_metric(**regression_metric_args) # Return error score if score is not finite - if not np.isfinite(score): + if not np.all(np.isfinite(score)): if isinstance(error_score, str): raise ValueError( 'Prediction diverged or error occured while scoring.') diff --git a/tests/test_koopman_pipeline.py b/tests/test_koopman_pipeline.py index bf939a8..cb39b60 100644 --- a/tests/test_koopman_pipeline.py +++ b/tests/test_koopman_pipeline.py @@ -323,8 +323,8 @@ class TestKoopmanPipelineScore: @pytest.mark.parametrize( 'X_predicted, X_expected, n_steps, discount_factor, ' - 'regression_metric, error_score, min_samples, episode_feature, ' - 'score_exp', + 'regression_metric, regression_metric_kw, error_score, min_samples, ' + 'episode_feature, score_exp', [ ( np.array([ @@ -338,11 +338,32 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, np.nan, 1, False, 0, ), + ( + np.array([ + [1, 2, 3, 4], + [2, 3, 3, 2], + ]).T, + np.array([ + [1, 2, 3, 4], + [2, 3, 3, 2], + ]).T, + None, + 1, + 'neg_mean_squared_error', + { + 'multioutput': 'raw_values', + }, + np.nan, + 1, + False, + np.array([0, 0]), + ), ( np.array([ [1, 2], @@ -355,6 +376,27 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + { + 'multioutput': 'raw_values', + }, + np.nan, + 1, + False, + -np.array([2**2, 1]), + ), + ( + np.array([ + [1, 2], + [2, 3], + ]).T, + np.array([ + [1, 4], + [2, 2], + ]).T, + None, + 1, + 'neg_mean_squared_error', + None, np.nan, 1, False, @@ -370,6 +412,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, np.nan, 1, False, @@ -385,6 +428,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_absolute_error', + None, np.nan, 1, False, @@ -400,6 +444,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, np.nan, 2, False, @@ -415,6 +460,7 @@ class TestKoopmanPipelineScore: 2, 1, 'neg_mean_squared_error', + None, np.nan, 1, False, @@ -430,6 +476,7 @@ class TestKoopmanPipelineScore: None, 0.5, 'neg_mean_squared_error', + None, np.nan, 1, False, @@ -448,6 +495,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, np.nan, 1, True, @@ -465,6 +513,7 @@ class TestKoopmanPipelineScore: 1, 1, 'neg_mean_squared_error', + None, np.nan, 1, True, @@ -482,6 +531,7 @@ class TestKoopmanPipelineScore: None, 0.5, 'neg_mean_squared_error', + None, np.nan, 1, True, @@ -499,6 +549,7 @@ class TestKoopmanPipelineScore: 1, 0.5, 'neg_mean_squared_error', + None, np.nan, 1, True, @@ -516,6 +567,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, np.nan, 1, False, @@ -533,6 +585,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, -100, 1, False, @@ -550,6 +603,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, 'raise', 1, False, @@ -565,6 +619,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, -100, 1, False, @@ -580,6 +635,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, 'raise', 1, False, @@ -596,6 +652,7 @@ class TestKoopmanPipelineScore: None, 1, 'neg_mean_squared_error', + None, -100, 1, False, @@ -610,6 +667,7 @@ def test_score_trajectory( n_steps, discount_factor, regression_metric, + regression_metric_kw, error_score, min_samples, episode_feature, @@ -623,6 +681,7 @@ def test_score_trajectory( n_steps=n_steps, discount_factor=discount_factor, regression_metric=regression_metric, + regression_metric_kw=regression_metric_kw, error_score=error_score, min_samples=min_samples, episode_feature=episode_feature, @@ -634,6 +693,7 @@ def test_score_trajectory( n_steps=n_steps, discount_factor=discount_factor, regression_metric=regression_metric, + regression_metric_kw=regression_metric_kw, error_score=error_score, min_samples=min_samples, episode_feature=episode_feature,