Skip to content

Commit

Permalink
Give UnigramVocabulary the option to split/join on the space character in all encode/decode functions.
Browse files Browse the repository at this point in the history

Previously, decode joined on space, while encode/encode_tf treated the entire input as a single token and decode_tf decoded only the first id. Now if "split_on_space" is True, all four functions are consistent with decode.

Having a UnigramVocabulary that encodes/decodes invertibly is useful for testing.

PiperOrigin-RevId: 651126298
  • Loading branch information
galenmandrew authored and SeqIO committed Jul 10, 2024
1 parent 568f9c4 commit e88bb6c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 10 deletions.
28 changes: 23 additions & 5 deletions seqio/vocabularies.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,17 @@ def __str__(self) -> str:
class UnigramVocabulary(Vocabulary):
"""Vocabulary that does table-lookup of unigrams."""

def __init__(self, unigrams: Sequence[str]):
def __init__(self, unigrams: Sequence[str], split_on_space: bool = False):
"""UnigramVocabulary constructor.
Args:
unigrams: the collection of in-vocabulary tokens. This collection should
not include PAD or UNK, which are automatically assigned ids and managed
as possible decode tokens.
split_on_space: if True, encode/decode split/join with the space
character. Otherwise, follows legacy behavior: encode (and encode_tf)
treats the input as a single token, decode joins with the space
character, and decode_tf decodes only the first token.
"""

super().__init__()
Expand All @@ -237,19 +241,33 @@ def __init__(self, unigrams: Sequence[str]):
initializer, num_oov_buckets=1
)
self._unigram_by_id_tf = tf.constant(self._unigram_by_id)
self._split_on_space = split_on_space

def _encode(self, s: str) -> Sequence[int]:
return [self._id_by_unigram.get(s, self.unk_id)]
if self._split_on_space:
return [
self._id_by_unigram.get(unigram, self.unk_id)
for unigram in s.split(" ")
]
else:
return [self._id_by_unigram.get(s, self.unk_id)]

def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
tf_ids = self._id_by_unigram_tf.lookup(s)
return tf.expand_dims(tf.dtypes.cast(tf_ids, tf.int32), -1)
if self._split_on_space:
tf_ids = self._id_by_unigram_tf.lookup(tf.strings.split(s, " "))
return tf.dtypes.cast(tf_ids, tf.int32)
else:
tf_ids = self._id_by_unigram_tf.lookup(s)
return tf.expand_dims(tf.dtypes.cast(tf_ids, tf.int32), -1)

def _decode(self, ids: Sequence[int]) -> str:
return " ".join(self._unigram_by_id[id] for id in ids)

def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
return self._unigram_by_id_tf[ids[0]]
if self._split_on_space:
return tf.strings.join(tf.gather(self._unigram_by_id_tf, ids), " ")
else:
return self._unigram_by_id_tf[ids[0]]

@property
def _base_vocab_size(self):
Expand Down
34 changes: 29 additions & 5 deletions seqio/vocabularies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,22 @@ def test_not_equal(self):



class UnigramVocabularyTest(absltest.TestCase):
class UnigramVocabularyTest(parameterized.TestCase):

def test_encode_converts_unigrams_to_ints_correctly(self):
@parameterized.parameters((True,), (False,))
def test_encode_converts_unigrams_to_ints_correctly(self, split_on_space):
unigrams = ["this", "that", "is", "not", "a", "the", "test", "ball"]
vocabulary = vocabularies.UnigramVocabulary(unigrams)
vocabulary = vocabularies.UnigramVocabulary(unigrams, split_on_space)
self.assertEqual(vocabulary.unk_id, 9)
with self.subTest(name="pure_python"):
# Note that id 0 is reserved for padding.
self.assertEqual(vocabulary.encode("that"), [2])
self.assertEqual(vocabulary.encode("not"), [4])
self.assertEqual(vocabulary.encode("apple"), [vocabulary.unk_id])
if split_on_space:
self.assertEqual(vocabulary.encode("not that"), [4, 2])
else:
self.assertEqual(vocabulary.encode("not that"), [vocabulary.unk_id])
with self.subTest(name="tensorflow"):
# Note that id 0 is reserved for padding.
# Note that this test must pass under both TF1 and TF2, but the default
Expand All @@ -229,14 +234,25 @@ def test_encode_converts_unigrams_to_ints_correctly(self):
vocabulary.encode_tf(tf.constant("apple")).numpy(),
[vocabulary.unk_id],
)
if split_on_space:
np.testing.assert_array_equal(
vocabulary.encode_tf(tf.constant("not that")).numpy(), [4, 2]
)
else:
np.testing.assert_array_equal(
vocabulary.encode_tf(tf.constant("not that")).numpy(),
[vocabulary.unk_id],
)

def test_decode_converts_ints_to_unigrams_correctly(self):
@parameterized.parameters((True,), (False,))
def test_decode_converts_ints_to_unigrams_correctly(self, split_on_space):
unigrams = ["this", "that", "is", "not", "a", "the", "test", "ball"]
vocabulary = vocabularies.UnigramVocabulary(unigrams)
vocabulary = vocabularies.UnigramVocabulary(unigrams, split_on_space)
with self.subTest(name="pure_python"):
self.assertEqual(vocabulary.decode([1]), "this")
self.assertEqual(vocabulary.decode([3]), "is")
self.assertEqual(vocabulary.decode([vocabulary.unk_id]), "UNK")
self.assertEqual(vocabulary.decode([1, 3]), "this is")
with self.subTest(name="tensorflow"):
# Note that this test must pass under both TF1 and TF2, but the default
# behavior of TF1 == among tensors is to compare object references, not
Expand All @@ -248,6 +264,14 @@ def test_decode_converts_ints_to_unigrams_correctly(self):
self.assertEqual(
vocabulary.decode_tf(tf.constant([vocabulary.unk_id])).numpy(), b"UNK"
)
if split_on_space:
self.assertEqual(
vocabulary.decode_tf(tf.constant([4, 2])).numpy(), b"not that"
)
else:
self.assertEqual(
vocabulary.decode_tf(tf.constant([4, 2])).numpy(), b"not"
)



Expand Down

0 comments on commit e88bb6c

Please sign in to comment.