Skip to content

Commit

Permalink
Merge pull request #152 from omenSi/master
Browse files Browse the repository at this point in the history
fix unix_sort
  • Loading branch information
realratchet authored Mar 18, 2024
2 parents 42fae2c + f2ba5b9 commit c353722
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 8 deletions.
13 changes: 8 additions & 5 deletions tablite/imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,22 @@ def nearest_neighbour(T, sources, missing, targets, tqdm=_tqdm, pbar=None):
values = [(v, k) for k, v in values.items()]
values.sort()
values = [k for _, k in values]

d = sort_utils.HashDict()
n = len([v for v in values if v not in missing])
d = {v: i / n if v not in missing else math.inf for i, v in enumerate(values)}
for i, v in enumerate(values):
d[v] = i / n if v not in missing else math.inf
normalised_values[name] = [d[v] for v in T[name]]
norm_index[name] = d
values.clear()

missing_value_index = T.index(*targets)
missing_value_index = {k: v for k, v in missing_value_index.items() if missing.intersection(set(k))} # strip out all that do not have missings.

ranks = set()
for k, v in missing_value_index.items():
ranks.update(set(k))
ranks = sort_utils.HashDict()
for k in missing_value_index.keys():
for vv in k:
ranks[vv] = True
ranks = ranks.keys()
item_order = sort_utils.unix_sort(list(ranks))
new_order = {tuple(item_order[i] for i in k): k for k in missing_value_index.keys()}

Expand Down
41 changes: 40 additions & 1 deletion tablite/sort_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Iterator
from datetime import datetime, date, time, timedelta
from pyuca import Collator
from tablite.datatypes import numpy_to_python


uca_collator = Collator()
Expand Down Expand Up @@ -184,8 +186,10 @@ def unix_sort(values, reverse=False):
text_code = _unix_typecodes[str]
text = [(text_code, ix, v) for ix, v in enumerate(text)]

d = HashDict()
L = non_text + text
d = {value: ix for ix, (_, _, value) in enumerate(L)}
for ix, (_, _, value) in enumerate(L):
d[value] = ix
return d


Expand Down Expand Up @@ -258,3 +262,38 @@ def rank(values, reverse, mode):
raise ValueError(f"{mode} not in list of modes: {list(modes)}")
f = modes.get(mode)
return f(values, reverse)


class HashDict(dict):
"""
This class is just a nicity syntatic sugar for debugging.
Function identically to regular dictionary, just uses tupled key.
"""

def _get_hash(self, key):
key = numpy_to_python(key)
return (type(key), key)

def items(self):
return [(k, v) for (_, k), v in super().items()]

def keys(self):
return [k for (_, k) in super().keys()]

def __iter__(self) -> Iterator:
return (k for (_, k) in super().keys())

def __getitem__(self, key):
return super().__getitem__(self._get_hash(key))

def __setitem__(self, key, value):
return super().__setitem__(self._get_hash(key), value)

def __contains__(self, key) -> bool:
return super().__contains__(self._get_hash(key))

def __delitem__(self, key):
return super().__delitem__(self._get_hash(key))

def __repr__(self) -> str:
return '{' + ", ".join([f"{k}: {v}" for (_, k), v in self.items()]) + '}'
2 changes: 1 addition & 1 deletion tablite/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
major, minor, patch = 2023, 10, 12
major, minor, patch = 2023, 10, 13
__version_info__ = (major, minor, patch)
__version__ = ".".join(str(i) for i in __version_info__)
22 changes: 21 additions & 1 deletion tests/test_sort.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tablite import Table
from datetime import datetime
from numpy import datetime64

from tablite.sort_utils import unix_sort

def test_sort():
t = Table(columns={"A": [4, 3, 2, 1], "B": [2, 2, 1, 1], "C": ["a", "d", "c", "b"]})
Expand Down Expand Up @@ -50,3 +50,23 @@ def test_sort_datetime():

t = Table(columns={"A": [datetime64("2005"), datetime64(datetime.now())], "B": [2, 2]})
t.sort({"A": False})


def test_unix_sort():
d = unix_sort([True, True, True, 0, 1, 1.0, False, 2])
assert False in d
assert d[False] == 0

assert True in d
assert d[True] == 3
assert 0 in d
assert d[0] == 4

assert 1 in d
assert d[1] == 5

assert 1.0 in d
assert d[1.0] == 6

assert 2 in d
assert d[2] == 7

0 comments on commit c353722

Please sign in to comment.