diff --git a/spiderexpress/strategies/spikyball.py b/spiderexpress/strategies/spikyball.py index 2727b44..e61d722 100644 --- a/spiderexpress/strategies/spikyball.py +++ b/spiderexpress/strategies/spikyball.py @@ -96,7 +96,10 @@ def calc_norm(source: pd.Series, edge: pd.Series, target: pd.Series) -> float: float : the normalization constant """ - return sum(source.fillna(1) * edge.fillna(1) * target.fillna(1)) + if any(source.isna(), edge.isna(), target.isna()): + log.warning("Input contains NaN values which will be replaced with 1.") + norm_const = (source.fillna(1) * edge.fillna(1) * target.fillna(1)).fillna(1).sum() + return norm_const def calc_prob(table: pd.DataFrame, params: ProbabilityConfiguration) -> pd.Series: