From d31f0a4b708b2b8f023d9b1bd2b8af756fbed7d2 Mon Sep 17 00:00:00 2001 From: Jason Chulock Date: Tue, 4 Jun 2024 22:52:34 +0000 Subject: [PATCH] Add disqualification check to billing model predict() Signed-off-by: Jason Chulock --- eemeter/eemeter/models/billing/model.py | 11 ++++++++++- tests/test_derivatives.py | 6 ++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/eemeter/eemeter/models/billing/model.py b/eemeter/eemeter/models/billing/model.py index f5fa650a..1a9247f2 100644 --- a/eemeter/eemeter/models/billing/model.py +++ b/eemeter/eemeter/models/billing/model.py @@ -22,7 +22,10 @@ import numpy as np import pandas as pd -from eemeter.eemeter.common.exceptions import DataSufficiencyError +from eemeter.eemeter.common.exceptions import ( + DataSufficiencyError, + DisqualifiedModelError, +) from eemeter.eemeter.common.warnings import EEMeterWarning from eemeter.eemeter.models.billing.data import ( BillingBaselineData, @@ -66,9 +69,15 @@ def predict( self, reporting_data: Union[BillingBaselineData, BillingReportingData], aggregation=None, + ignore_disqualification=False, ): if not self.is_fitted: raise RuntimeError("Model must be fit before predictions can be made.") + + if self.disqualification and not ignore_disqualification: + raise DisqualifiedModelError( + "Attempting to predict using disqualified model without setting ignore_disqualification=True" + ) if not isinstance(reporting_data, (BillingBaselineData, BillingReportingData)): raise TypeError( diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 679355a7..5ae90196 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -248,7 +248,8 @@ def test_metered_savings_cdd_hdd_billing_single_record_baseline_data( reporting_meter_data_billing, reporting_temperature_data, is_electricity_data=True, - ) + ), + ignore_disqualification=True, ) assert list(results.columns) == [ "season", @@ -524,7 +525,8 @@ def test_metered_savings_model_single_record( reporting_meter_data_billing, reporting_temperature_data, is_electricity_data=True, - ) + ), + ignore_disqualification=True, ) assert list(results.columns) == [ "season",