Commit

Implement eq in SentencePieceModel based on __getstate__
This allows comparing two vocabularies without loading the model. Currently, eq is implemented by comparing the md5 checksum of the loaded models, which requires the model to be loaded. It also ignores other parameters of the vocabulary such as `extra_ids`, `reverse_extra_ids`, etc.

PiperOrigin-RevId: 684142783
tomvdw authored and SeqIO committed Oct 15, 2024
1 parent bbb5ba3 commit fddb1d9
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions seqio/vocabularies.py
@@ -555,15 +555,7 @@ def _decode_tf(self, ids):
   def __eq__(self, other):
     if not isinstance(other, SentencePieceVocabulary):
       return False
-    try:
-      their_md5 = hashlib.md5(other.sp_model).hexdigest()
-    # If other has no sp_model attribute, we can't test for equality
-    except AttributeError:
-      return False
-    if self.sp_model is None:
-      return False
-    our_md5 = hashlib.md5(self.sp_model).hexdigest()
-    return our_md5 == their_md5
+    return self.__getstate__() == other.__getstate__()

   def __str__(self) -> str:
     return (
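The idea in the commit message, equality derived from `__getstate__` rather than from the loaded model, can be illustrated with a minimal sketch. Everything here is hypothetical: `LazyVocab`, its `model_path` and `extra_ids` fields, and the contents of the state dict are stand-ins chosen for illustration, since this diff does not show what `SentencePieceVocabulary.__getstate__` actually returns.

```python
class LazyVocab:
  """Hypothetical stand-in for a vocabulary that loads its model lazily."""

  def __init__(self, model_path, extra_ids=0):
    self.model_path = model_path
    self.extra_ids = extra_ids
    self._model = None  # Would be loaded on first use; never loaded in this sketch.

  def __getstate__(self):
    # Assumption: the state holds only the configuration, not the loaded model.
    return {"model_path": self.model_path, "extra_ids": self.extra_ids}

  def __setstate__(self, state):
    self.__dict__.update(state)
    self._model = None

  def __eq__(self, other):
    if not isinstance(other, LazyVocab):
      return False
    # Comparing states needs no model load and covers every configured field.
    return self.__getstate__() == other.__getstate__()

  def __hash__(self):
    # Defining __eq__ disables the default hash, so restore a consistent one.
    return hash((self.model_path, self.extra_ids))


a = LazyVocab("spm.model", extra_ids=100)
b = LazyVocab("spm.model", extra_ids=100)
c = LazyVocab("spm.model", extra_ids=0)
```

In this sketch `a == b` holds while `a == c` does not, and neither comparison touches `_model`. One trade-off of a state-based comparison like this is that two objects pointing at byte-identical model files under different paths would compare unequal; whether that applies to the real class depends on what its `__getstate__` includes.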
