Skip to content

Commit

Permalink
Add disqualification check to billing model predict()
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Chulock <[email protected]>
  • Loading branch information
jason-recurve committed Jun 4, 2024
1 parent b315053 commit d31f0a4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
11 changes: 10 additions & 1 deletion eemeter/eemeter/models/billing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import numpy as np
import pandas as pd

from eemeter.eemeter.common.exceptions import DataSufficiencyError
from eemeter.eemeter.common.exceptions import (
DataSufficiencyError,
DisqualifiedModelError,
)
from eemeter.eemeter.common.warnings import EEMeterWarning
from eemeter.eemeter.models.billing.data import (
BillingBaselineData,
Expand Down Expand Up @@ -66,9 +69,15 @@ def predict(
self,
reporting_data: Union[BillingBaselineData, BillingReportingData],
aggregation=None,
ignore_disqualification=False,
):
if not self.is_fitted:
raise RuntimeError("Model must be fit before predictions can be made.")

if self.disqualification and not ignore_disqualification:
raise DisqualifiedModelError(
"Attempting to predict using disqualified model without setting ignore_disqualification=True"
)

if not isinstance(reporting_data, (BillingBaselineData, BillingReportingData)):
raise TypeError(
Expand Down
6 changes: 4 additions & 2 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def test_metered_savings_cdd_hdd_billing_single_record_baseline_data(
reporting_meter_data_billing,
reporting_temperature_data,
is_electricity_data=True,
)
),
ignore_disqualification=True,
)
assert list(results.columns) == [
"season",
Expand Down Expand Up @@ -524,7 +525,8 @@ def test_metered_savings_model_single_record(
reporting_meter_data_billing,
reporting_temperature_data,
is_electricity_data=True,
)
),
ignore_disqualification=True,
)
assert list(results.columns) == [
"season",
Expand Down

0 comments on commit d31f0a4

Please sign in to comment.