Skip to content

Commit

Permalink
Merge pull request #146 from realratchet/master
Browse files (browse the repository at this point in the history)
Fix imputations
  • Loading branch information
realratchet authored Mar 8, 2024
2 parents 32fc7b3 + 09f4353 commit 3bdf7ff
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
21 changes: 16 additions & 5 deletions tablite/imputation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

from tablite.base import BaseTable, Column
from tablite.utils import sub_cls_check
from tablite.utils import sub_cls_check, summary_statistics
from tablite.config import Config
from tablite import sort_utils

Expand Down Expand Up @@ -90,11 +90,11 @@ def imputation(T, targets, missing=None, method="carry forward", sources=None, t
methods = ["nearest neighbour", "mean", "mode", "carry forward"]

if method == "carry forward":
return carry_forward(T, targets, missing, tqdm=_tqdm, pbar=None)
return carry_forward(T, targets, missing, tqdm=tqdm, pbar=pbar)
elif method in {"mean", "mode"}:
return stats_method(T, targets, missing, method, tqdm=_tqdm, pbar=None)
return stats_method(T, targets, missing, method, tqdm=tqdm, pbar=pbar)
elif method == "nearest neighbour":
return nearest_neighbour(T, sources, missing, targets, tqdm=_tqdm, pbar=None)
return nearest_neighbour(T, sources, missing, targets, tqdm=tqdm, pbar=pbar)
else:
raise ValueError(f"method {method} not recognised amonst known methods: {list(methods)})")

Expand Down Expand Up @@ -136,7 +136,18 @@ def stats_method(T, targets, missing, method, tqdm=_tqdm, pbar=None):
if name in targets:
col = T.columns[name]
assert isinstance(col, Column)
stats = col.statistics()

hist_values, hist_counts = col.histogram()

for m in missing:
try:
idx = hist_values.index(m)
hist_counts[idx] = 0
except ValueError:
pass

stats = summary_statistics(hist_values, hist_counts)

new_value = stats[method]
col.replace(mapping={m: new_value for m in missing})
new[name] = col
Expand Down
2 changes: 1 addition & 1 deletion tablite/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
major, minor, patch = 2023, 10, 7
major, minor, patch = 2023, 10, 8
__version_info__ = (major, minor, patch)
__version__ = ".".join(str(i) for i in __version_info__)

0 comments on commit 3bdf7ff

Please sign in to comment.