From a62e9ab9891f0d890a87b7e6b94b6f0e380a2e25 Mon Sep 17 00:00:00 2001 From: Tyler Barrus Date: Wed, 3 Jan 2024 20:20:31 -0500 Subject: [PATCH] 37 implement a quotient filter (#111) * basic quotient filter, to be expanded upon * basic bit array implementation --- .pylintrc | 2 +- CHANGELOG.md | 5 + docs/source/code.rst | 35 ++++ docs/source/conf.py | 9 +- docs/source/index.rst | 1 + docs/source/quickstart.rst | 28 +++ probables/__init__.py | 20 +- probables/blooms/bloom.py | 16 +- probables/hashes.py | 22 +- probables/quotientfilter/__init__.py | 6 + probables/quotientfilter/py.typed | 0 probables/quotientfilter/quotientfilter.py | 231 +++++++++++++++++++++ probables/utilities.py | 114 ++++++++++ tests/hashes_test.py | 8 + tests/quotientfilter_test.py | 93 +++++++++ tests/test_utilities.py | 84 +++++++- 16 files changed, 642 insertions(+), 32 deletions(-) create mode 100644 probables/quotientfilter/__init__.py create mode 100644 probables/quotientfilter/py.typed create mode 100644 probables/quotientfilter/quotientfilter.py create mode 100644 tests/quotientfilter_test.py diff --git a/.pylintrc b/.pylintrc index 3f8d312..f101b55 100644 --- a/.pylintrc +++ b/.pylintrc @@ -535,7 +535,7 @@ function-naming-style=snake_case #function-rgx= # Good variable names which should always be accepted, separated by a comma. -good-names=i,j,k,b,f,v,m,n,p,d,hh,st,ex,Run,_ +good-names=i,j,k,b,f,v,m,n,p,d,hh,st,ex,Run,_,r,q # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted diff --git a/CHANGELOG.md b/CHANGELOG.md index e8dd3f4..6bea312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # PyProbables Changelog +### Version 0.6.0 + +* Add `QuotientFilter` implementation; [see issue #37](https://github.com/barrust/pyprobables/issues/37) +* Add `bitarray` implementation + ### Version 0.5.9 * Add `py.typed` files so that mypy will find type annotations diff --git a/docs/source/code.rst b/docs/source/code.rst index 8a5c310..88e58a2 100644 --- a/docs/source/code.rst +++ b/docs/source/code.rst @@ -19,6 +19,7 @@ operations. Bloom Filters guarantee a zero percent false negative rate and a predetermined false positive rate. Once the number of elements inserted exceeds the estimated elements, the false positive rate will increase over the desired amount. + `Further Reading `__ @@ -69,6 +70,7 @@ membership testing. Cuckoo filters support insertion, deletion, and lookup of elements with low overhead and few false positive results. The name is derived from the `cuckoo hashing `__ strategy used to resolve conflicts. + `Further Reading `__ CuckooFilter @@ -92,6 +94,7 @@ data elements. The result is a probabilistic count of elements inserted into the data structure. It will always provide the **maximum** number of times a data element was encountered. Notice that the result may be **more** than the true number of times it was inserted, but never fewer. + `Further Reading `__ @@ -137,6 +140,38 @@ StreamThreshold For more information of all methods and properties, see `CountMinSketch`_. +QuotientFilter +------------------ + +Quotient filters are an aproximate membership query filter (AMQ) that is both +space efficient and returns a zero false negative rate and a probablistic false +positive rate. Unlike Bloom filters, the quotient filter only requires a single +hash of the element to insert. The upper **q** bits denote the location within the +filter while the lower **r** bits are stored in the filter. + +Quotient filters provide some useful benifits over Bloom filters including: + +* Merging of two filters (not union) +* Resizing of the filter +* Ability to remove elements + +`Further Reading `__ + +QuotientFilter ++++++++++++++++++++++++++++++++ + +.. autoclass:: probables.QuotientFilter + :members: + + +Utilities +------------------ + +Bitarray ++++++++++++++++++++++++++++++++ + +.. autoclass:: probables.utilities.Bitarray + :members: Exceptions ============================ diff --git a/docs/source/conf.py b/docs/source/conf.py index 66f98df..f189d6e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -103,7 +103,14 @@ # further. For a list of options available for each theme, see the # documentation. # -# html_theme_options = {} + +html_theme_options = { + # "collapse_navigation": True, + # "sticky_navigation": True, + # "navigation_depth": 4, + # "includehidden": True, + # "titles_only": False, +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/source/index.rst b/docs/source/index.rst index 4935802..5b6e468 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,6 +1,7 @@ .. _home: .. include:: ../../README.rst + .. toctree:: code diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index e3f5d6c..d97552e 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -253,6 +253,34 @@ The counting cuckoo filter is similar to the standard filter except that it tracks the number of times a fingerprint has been added to the filter. +Quotient Filters +---------------- + +Quotient Filters provide set operations of large datasets while being relatively +small in memory footprint. They provide a zero percent false negative rate and a +small false positive rate. +`more information `__ + + +Import, Initialize, and Train +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. code:: python3 + + >>> qf = QuotientFilter(quotient=22) + >>> with open('war_and_peace.txt', 'r') as fp: + >>> for line in fp: + >>> for word in line.split(): + >>> blm.add(word.lower()) # add each word to the bloom filter! + + +Query the Quotient Filter +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. code:: python3 + + >>> words_to_check = ['borzoi', 'diametrically', 'fleches', 'rain', 'foo'] + >>> for word in words_to_check: + >>> print(qf.check(word)) # prints: True, True, True, True, False + Custom Hashing Functions ---------------------------------- In many instances, to get the best raw performance out of the data structures, diff --git a/probables/__init__.py b/probables/__init__.py index 845434c..84414e5 100644 --- a/probables/__init__.py +++ b/probables/__init__.py @@ -1,19 +1,7 @@ """ pyprobables module """ -from .blooms import ( - BloomFilter, - BloomFilterOnDisk, - CountingBloomFilter, - ExpandingBloomFilter, - RotatingBloomFilter, -) -from .countminsketch import ( - CountMeanMinSketch, - CountMeanSketch, - CountMinSketch, - HeavyHitters, - StreamThreshold, -) +from .blooms import BloomFilter, BloomFilterOnDisk, CountingBloomFilter, ExpandingBloomFilter, RotatingBloomFilter +from .countminsketch import CountMeanMinSketch, CountMeanSketch, CountMinSketch, HeavyHitters, StreamThreshold from .cuckoo import CountingCuckooFilter, CuckooFilter from .exceptions import ( CuckooFilterFullError, @@ -22,6 +10,8 @@ ProbablesBaseException, RotatingBloomFilterError, ) +from .quotientfilter import QuotientFilter +from .utilities import Bitarray __author__ = "Tyler Barrus" __maintainer__ = "Tyler Barrus" @@ -50,4 +40,6 @@ "ExpandingBloomFilter", "RotatingBloomFilter", "RotatingBloomFilterError", + "QuotientFilter", + "Bitarray", ] diff --git a/probables/blooms/bloom.py b/probables/blooms/bloom.py index edb4658..5a47693 100644 --- a/probables/blooms/bloom.py +++ b/probables/blooms/bloom.py @@ -286,7 +286,7 @@ def export(self, file: Union[Path, str, IOBase, mmap]) -> None: """Export the Bloom Filter to disk Args: - filename (str): The filename to which the Bloom Filter will be written.""" + file (str): The file or filepath to which the Bloom Filter will be written.""" if not isinstance(file, (IOBase, mmap)): file = resolve_path(file) with open(file, "wb") as filepointer: @@ -658,23 +658,23 @@ def close(self) -> None: self.__file_pointer.close() self.__file_pointer = None - def export(self, filename: Union[str, Path]) -> None: # type: ignore + def export(self, file: Union[str, Path]) -> None: # type: ignore """Export to disk if a different location Args: - filename (str): The filename to which the Bloom Filter will be exported + file (str|Path): The filename to which the Bloom Filter will be exported Note: Only exported if the filename is not the original filename""" self.__update() - if filename and Path(filename) != self._filepath: - copyfile(self._filepath.name, str(filename)) + if file and Path(file) != self._filepath: + copyfile(self._filepath.name, str(file)) # otherwise, nothing to do! - def _load(self, filepath: Union[str, Path], hash_function: Union[HashFuncT, None] = None): # type: ignore + def _load(self, file: Union[str, Path], hash_function: Union[HashFuncT, None] = None): # type: ignore """load the Bloom Filter on disk""" # read the file, set the optimal params # mmap everything - file = resolve_path(filepath) + file = resolve_path(file) with open(file, "r+b") as filepointer: offset = self._FOOTER_STRUCT.size filepointer.seek(offset * -1, os.SEEK_END) @@ -683,7 +683,7 @@ def _load(self, filepath: Union[str, Path], hash_function: Union[HashFuncT, None fpr, n_hashes, n_bits = self._get_optimized_params(est_els, fpr) self._set_values(est_els, fpr, n_hashes, n_bits, hash_function) # setup a few additional items - self.__file_pointer = open(filepath, "r+b") # type: ignore + self.__file_pointer = open(file, "r+b") # type: ignore self._bloom = mmap(self.__file_pointer.fileno(), 0) # type: ignore self._on_disk = True diff --git a/probables/hashes.py b/probables/hashes.py index 0f90bec..00b9577 100644 --- a/probables/hashes.py +++ b/probables/hashes.py @@ -5,7 +5,7 @@ from struct import unpack from typing import Callable, List, Union -from .constants import UINT64_T_MAX +from .constants import UINT32_T_MAX, UINT64_T_MAX KeyT = Union[str, bytes] SimpleHashT = Callable[[KeyT, int], int] @@ -103,6 +103,26 @@ def fnv_1a(key: KeyT, seed: int = 0) -> int: return hval +def fnv_1a_32(key: KeyT, seed: int = 0) -> int: + """Pure python implementation of the 32 bit fnv-1a hash + Args: + key (str): The element to be hashed + seed (int): Add a seed to the initial starting point (0 means no seed) + Returns: + int: 32-bit hashed representation of key + Note: + Uses the lower 32 bits when overflows occur""" + max32mod = UINT32_T_MAX + 1 + hval = (0x811C9DC5 + (31 * seed)) % max32mod + fnv_32_prime = 0x01000193 + tmp = list(key) if not isinstance(key, str) else list(map(ord, key)) + for t_str in tmp: + hval ^= t_str + hval *= fnv_32_prime + hval %= max32mod + return hval + + @hash_with_depth_bytes def default_md5(key: KeyT, *args, **kwargs) -> bytes: """The default md5 hashing routine diff --git a/probables/quotientfilter/__init__.py b/probables/quotientfilter/__init__.py new file mode 100644 index 0000000..13f2444 --- /dev/null +++ b/probables/quotientfilter/__init__.py @@ -0,0 +1,6 @@ +""" Quotient Filters """ + + +from .quotientfilter import QuotientFilter + +__all__ = ["QuotientFilter"] diff --git a/probables/quotientfilter/py.typed b/probables/quotientfilter/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/probables/quotientfilter/quotientfilter.py b/probables/quotientfilter/quotientfilter.py new file mode 100644 index 0000000..90d3322 --- /dev/null +++ b/probables/quotientfilter/quotientfilter.py @@ -0,0 +1,231 @@ +""" BloomFilter and BloomFiter on Disk, python implementation + License: MIT + Author: Tyler Barrus (barrust@gmail.com) +""" + +from array import array + +from probables.hashes import HashFuncT, KeyT, fnv_1a_32 +from probables.utilities import Bitarray + + +class QuotientFilter: + """Simple Quotient Filter implementation + + Args: + quotient (int): The size of the quotient to use + hash_function (function): Hashing strategy function to use `hf(key, number)` + Returns: + QuotientFilter: The initialized filter + Raises: + ValueError: + Note: + The size of the QuotientFilter will be 2**q""" + + __slots__ = ( + "_q", + "_r", + "_size", + "_elements_added", + "_hash_func", + "_int_type_code", + "_bits_per_elm", + "_is_occupied", + "_is_continuation", + "_is_shifted", + "_filter", + ) + + def __init__(self, quotient: int = 20, hash_function: HashFuncT = None): # needs to be parameterized + if quotient < 3 or quotient > 31: + raise ValueError( + f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" + ) + self._q = quotient + self._r = 32 - quotient + self._size = 1 << self._q # same as 2**q + self._elements_added = 0 + self._hash_func = fnv_1a_32 if hash_function is None else hash_function + + # ensure we use the smallest type possible to reduce memory wastage + if self._r <= 8: + self._int_type_code = "B" + self._bits_per_elm = 8 + elif self._r <= 16: + self._int_type_code = "I" + self._bits_per_elm = 16 + else: + self._int_type_code = "L" + self._bits_per_elm = 32 + + self._is_occupied = Bitarray(self._size) + self._is_continuation = Bitarray(self._size) + self._is_shifted = Bitarray(self._size) + self._filter = array(self._int_type_code, [0]) * self._size + + def __contains__(self, val: KeyT) -> bool: + """setup the `in` keyword""" + return self.check(val) + + @property + def quotient(self) -> int: + """int: The size of the quotient, in bits""" + return self._q + + @property + def remainder(self) -> int: + """int: The size of the remainder, in bits""" + return self._r + + @property + def num_elements(self) -> int: + """int: The total size of the filter""" + return self._size + + @property + def elements_added(self) -> int: + """int: The number of elements added to the filter""" + return self._elements_added + + @property + def bits_per_elm(self): + """int: The number of bits used per element""" + return self._bits_per_elm + + def add(self, key: KeyT) -> None: + """Add key to the quotient filter + + Args: + key (str|bytes): The element to add""" + _hash = self._hash_func(key) + key_quotient = _hash >> self._r + key_remainder = _hash & ((1 << self._r) - 1) + + if not self._contains(key_quotient, key_remainder): + # TODO, add it here + self._add(key_quotient, key_remainder) + + def check(self, key: KeyT) -> bool: + """Check to see if key is likely in the quotient filter + + Args: + key (str|bytes): The element to add + Return: + bool: True if likely encountered, False if definately not""" + _hash = self._hash_func(key) + key_quotient = _hash >> self._r + key_remainder = _hash & ((1 << self._r) - 1) + return self._contains(key_quotient, key_remainder) + + def _shift_insert(self, k, v, start, j, flag): + if self._is_occupied[j] == 0 and self._is_continuation[j] == 0 and self._is_shifted[j] == 0: + self._filter[j] = v + self._is_occupied[k] = 1 + self._is_continuation[j] = 1 if j != start else 0 + self._is_shifted[j] = 1 if j != k else 0 + + else: + i = (j + 1) & (self._size - 1) + + while True: + f = self._is_occupied[i] + self._is_continuation[i] + self._is_shifted[i] + + temp = self._is_continuation[i] + self._is_continuation[i] = self._is_continuation[j] + self._is_continuation[j] = temp + + self._is_shifted[i] = 1 + + temp = self._filter[i] + self._filter[i] = self._filter[j] + self._filter[j] = temp + + if f == 0: + break + + i = (i + 1) & (self._size - 1) + + self._filter[j] = v + self._is_occupied[k] = 1 + self._is_continuation[j] = 1 if j != start else 0 + self._is_shifted[j] = 1 if j != k else 0 + + if flag == 1: + self._is_continuation[(j + 1) & (self._size - 1)] = 1 + + def _get_start_index(self, k): + j = k + cnts = 0 + + while True: + if j == k or self._is_occupied[j] == 1: + cnts += 1 + + if self._is_shifted[j] == 1: + j = (j - 1) & (self._size - 1) + else: + break + + while True: + if self._is_continuation[j] == 0: + if cnts == 1: + break + cnts -= 1 + + j = (j + 1) & (self._size - 1) + + return j + + def _add(self, q: int, r: int): + if self._is_occupied[q] == 0 and self._is_continuation[q] == 0 and self._is_shifted[q] == 0: + self._filter[q] = r + self._is_occupied[q] = 1 + + else: + start_idx = self._get_start_index(q) + + if self._is_occupied[q] == 0: + self._shift_insert(q, r, start_idx, start_idx, 0) + + else: + orig_start_idx = start_idx + starts = 0 + f = self._is_occupied[start_idx] + self._is_continuation[start_idx] + self._is_shifted[start_idx] + + while starts == 0 and f != 0 and r > self._filter[start_idx]: + start_idx = (start_idx + 1) & (self._size - 1) + + if self._is_continuation[start_idx] == 0: + starts += 1 + + f = self._is_occupied[start_idx] + self._is_continuation[start_idx] + self._is_shifted[start_idx] + + if starts == 1: + self._shift_insert(q, r, orig_start_idx, start_idx, 0) + else: + self._shift_insert(q, r, orig_start_idx, start_idx, 1) + self._elements_added += 1 + + def _contains(self, q: int, r: int) -> bool: + if self._is_occupied[q] == 0: + return False + + start_idx = self._get_start_index(q) + + starts = 0 + meta_bits = self._is_occupied[start_idx] + self._is_continuation[start_idx] + self._is_shifted[start_idx] + + while meta_bits != 0: + if self._is_continuation[start_idx] == 0: + starts += 1 + + if starts == 2 or self._filter[start_idx] > r: + break + + if self._filter[start_idx] == r: + return True + + start_idx = (start_idx + 1) & (self._size - 1) + meta_bits = self._is_occupied[start_idx] + self._is_continuation[start_idx] + self._is_shifted[start_idx] + + return False diff --git a/probables/utilities.py b/probables/utilities.py index 9d41629..2bd1bc1 100644 --- a/probables/utilities.py +++ b/probables/utilities.py @@ -1,7 +1,9 @@ """ Utility Functions """ +import math import mmap import string +from array import array from pathlib import Path from typing import Union @@ -83,3 +85,115 @@ def seek(self, pos: int, whence: int) -> None: def read(self, n: int = -1) -> bytes: """Implement a method to read from the file on top of the MMap class""" return self.__m.read(n) + + +class Bitarray: + """Simplified, pure python bitarray implementation using as little memory as possible + + Args: + size (int): The number of bits in the bitarray + Returns: + Bitarray: A bitarray + Raises: + TypeError: + ValueError:""" + + def __init__(self, size: int): + if not isinstance(size, int): + raise TypeError(f"Bitarray size must be an int; {type(size)} was provided") + if size <= 0: + raise ValueError(f"Bitarray size must be larger than 1; {size} was provided") + self._size_bytes = math.ceil(size / 8) + self._bitarray = array("B", [0]) * self._size_bytes + self._size = size + + @property + def size_bytes(self) -> int: + """The size of the bitarray in bytes""" + return self._size_bytes + + @property + def size(self) -> int: + """The number of bits in the bitarray""" + return self._size + + @property + def bitarray(self) -> array: + """The bitarray""" + return self._bitarray + + def __getitem__(self, key: Union[int, slice]) -> int: + if isinstance(key, slice): + indices = range(*key.indices(self._size)) + return [self.check_bit(i) for i in indices] + return self.check_bit(key) + + def __setitem__(self, idx: int, val: int): + if val < 0 or val > 1: + raise ValueError("Invalid bit setting; must be 0 or 1") + if idx < 0 or idx >= self._size: + raise IndexError(f"Bitarray index outside of range; index {idx} was provided") + b = idx // 8 + if val == 1: + self._bitarray[b] = self._bitarray[b] | (1 << (idx % 8)) + else: + self._bitarray[b] = self._bitarray[b] & ~(1 << (idx % 8)) + + def check_bit(self, idx: int) -> int: + """Check if the bit idx is set + + Args: + idx (int): The index to check + Returns: + int: The status of the bit, either 0 or 1""" + if idx < 0 or idx >= self._size: + raise IndexError(f"Bitarray index outside of range; index {idx} was provided") + return 0 if (self._bitarray[idx // 8] & (1 << (idx % 8))) == 0 else 1 + + def is_bit_set(self, idx: int) -> bool: + """Check if the bit idx is set + + Args: + idx (int): The index to check + Returns: + int: The status of the bit, either 0 or 1""" + return bool(self.check_bit(idx)) + + def set_bit(self, idx: int) -> None: + """Set the bit at idx to 1 + + Args: + idx (int): The index to set""" + if idx < 0 or idx >= self._size: + raise IndexError(f"Bitarray index outside of range; index {idx} was provided") + b = idx // 8 + self._bitarray[b] = self._bitarray[b] | (1 << (idx % 8)) + + def clear_bit(self, idx: int) -> None: + """Set the bit at idx to 0 + + Args: + idx (int): The index to clear""" + if idx < 0 or idx >= self._size: + raise IndexError(f"Bitarray index outside of range; index {idx} was provided") + b = idx // 8 + self._bitarray[b] = self._bitarray[b] & ~(1 << (idx % 8)) + + def clear(self): + """Clear all bits in the bitarray""" + for i in range(self._size_bytes): + self._bitarray[i] = 0 + + def as_string(self): + """String representation of the bitarray + + Returns: + str: Bitarray representation as a string""" + return "".join(str(self.check_bit(x)) for x in range(self._size)) + + def num_bits_set(self) -> int: + """Number of bits set in the bitarray + + Returns: + int: Number of bits set""" + return sum(self.check_bit(x) for x in range(self._size)) diff --git a/tests/hashes_test.py b/tests/hashes_test.py index 0f10af9..0fa7cb7 100755 --- a/tests/hashes_test.py +++ b/tests/hashes_test.py @@ -16,6 +16,7 @@ default_fnv_1a, default_md5, default_sha256, + fnv_1a_32, hash_with_depth_bytes, hash_with_depth_int, ) @@ -54,6 +55,13 @@ def test_default_hash_colision(self): for i in range(1, 5): self.assertNotEqual(h1[i], h2[i]) + def test_fnv_1a_32(self): + """test fnv_1a 32 bit hash""" + hash = fnv_1a_32("this is a test", 0) + self.assertEqual(hash, 2139996864) + hash = fnv_1a_32("this is also a test", 0) + self.assertEqual(hash, 1462718619) + def test_default_md5(self): """test default md5 algorithm""" this_is_a_test = [ diff --git a/tests/quotientfilter_test.py b/tests/quotientfilter_test.py new file mode 100644 index 0000000..8c6c8b7 --- /dev/null +++ b/tests/quotientfilter_test.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" Unittest class """ + +import hashlib +import os +import sys +import unittest +from pathlib import Path +from tempfile import NamedTemporaryFile + +this_dir = Path(__file__).parent +sys.path.insert(0, str(this_dir)) +sys.path.insert(0, str(this_dir.parent)) +from utilities import calc_file_md5, different_hash + +from probables import QuotientFilter + +DELETE_TEMP_FILES = True + + +class TestQuotientFilter(unittest.TestCase): + """Test the default quotient filter implementation""" + + def test_qf_init(self): + "test initializing a blank quotient filter" + qf = QuotientFilter() + + self.assertEqual(qf.bits_per_elm, 16) + self.assertEqual(qf.quotient, 20) + self.assertEqual(qf.remainder, 12) + self.assertEqual(qf.elements_added, 0) + self.assertEqual(qf.num_elements, 1048576) # 2**qf.quotient + + qf = QuotientFilter(quotient=8) + + self.assertEqual(qf.bits_per_elm, 32) + self.assertEqual(qf.quotient, 8) + self.assertEqual(qf.remainder, 24) + self.assertEqual(qf.elements_added, 0) + self.assertEqual(qf.num_elements, 256) # 2**qf.quotient + + qf = QuotientFilter(quotient=24) + + self.assertEqual(qf.bits_per_elm, 8) + self.assertEqual(qf.quotient, 24) + self.assertEqual(qf.remainder, 8) + self.assertEqual(qf.elements_added, 0) + self.assertEqual(qf.num_elements, 16777216) # 2**qf.quotient + + def test_qf_add_check(self): + "test that the qf is able to add and check elements" + qf = QuotientFilter(quotient=8) + + for i in range(0, 200, 2): + qf.add(str(i)) + self.assertEqual(qf.elements_added, 100) + + found_no = False + for i in range(0, 200, 2): + if not qf.check(str(i)): + found_no = True + self.assertFalse(found_no) + + for i in range(1, 200, 2): + print(i) + self.assertFalse(qf.check(str(i))) + + self.assertEqual(qf.elements_added, 100) + + def test_qf_add_check_in(self): + "test that the qf is able to add and check elements using `in`" + qf = QuotientFilter(quotient=8) + + for i in range(0, 200, 2): + qf.add(str(i)) + self.assertEqual(qf.elements_added, 100) + + found_no = False + for i in range(0, 200, 2): + if str(i) not in qf: + found_no = True + self.assertFalse(found_no) + + for i in range(1, 200, 2): + print(i) + self.assertFalse(str(i) in qf) + + self.assertEqual(qf.elements_added, 100) + + def test_qf_errors(self): + self.assertRaises(ValueError, lambda: QuotientFilter(quotient=2)) + self.assertRaises(ValueError, lambda: QuotientFilter(quotient=32)) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index ba381b4..2180beb 100755 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -13,13 +13,7 @@ from utilities import different_hash -from probables.utilities import ( - MMap, - get_x_bits, - is_hex_string, - is_valid_file, - resolve_path, -) +from probables.utilities import Bitarray, MMap, get_x_bits, is_hex_string, is_valid_file, resolve_path DELETE_TEMP_FILES = True @@ -115,6 +109,82 @@ def test_resolve_path(self): p2 = resolve_path("./{}".format(fobj.name)) self.assertTrue(p2.is_absolute()) + def test_bitarray(self): + """test bit array basic operations""" + ba = Bitarray(100) + + self.assertEqual(ba.size, 100) + self.assertEqual(ba.size_bytes, 13) + for i in range(ba.size_bytes): + self.assertEqual(0, ba.bitarray[i]) + + # test setting bits + for i in range(33): + ba.set_bit(i * 3) + + self.assertEqual( + ba.as_string(), + "1001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001000", + ) + self.assertEqual(ba.num_bits_set(), 33) + self.assertTrue(ba.is_bit_set(3)) + self.assertFalse(ba.is_bit_set(4)) + self.assertEqual(ba[0], 1) + self.assertEqual(ba[1], 0) + + # test clearing bits + for i in range(33): + ba.clear_bit(i * 3) + + self.assertEqual( + ba.as_string(), + "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ) + + for i in range(33): + ba.set_bit(i * 3) + self.assertEqual( + ba.as_string(), + "1001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001001000", + ) + + self.assertEqual(ba[-5::], [0, 1, 0, 0, 0]) + self.assertEqual(ba[2], 0) + ba[2] = 1 + self.assertEqual(ba[2], 1) + ba[2] = 0 + self.assertEqual(ba[2], 0) + + ba.clear() + self.assertEqual( + ba.as_string(), + "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ) + + def test_bitarray_invalid_idx(self): + """use an invalid type in a jaccard index""" + self.assertRaises(TypeError, lambda: Bitarray("100")) + self.assertRaises(ValueError, lambda: Bitarray(-100)) + ba = Bitarray(10) + self.assertRaises(IndexError, lambda: ba.set_bit(12)) + self.assertRaises(IndexError, lambda: ba.set_bit(-1)) + self.assertRaises(IndexError, lambda: ba.check_bit(-1)) + self.assertRaises(IndexError, lambda: ba.check_bit(12)) + self.assertRaises(IndexError, lambda: ba.clear_bit(-1)) + self.assertRaises(IndexError, lambda: ba.clear_bit(12)) + + self.assertRaises(IndexError, lambda: ba[-1]) + self.assertRaises(IndexError, lambda: ba[12]) + + def test_set(idx, val): + ba[idx] = val + + self.assertRaises(IndexError, lambda: test_set(-1, 0)) + self.assertRaises(IndexError, lambda: test_set(12, 0)) + # set as non-valid bit value + self.assertRaises(ValueError, lambda: test_set(1, 5)) + self.assertRaises(ValueError, lambda: test_set(12, -1)) + if __name__ == "__main__": unittest.main()