diff --git a/seqio/vocabularies.py b/seqio/vocabularies.py
index 079fc81c..6895565c 100644
--- a/seqio/vocabularies.py
+++ b/seqio/vocabularies.py
@@ -215,13 +215,17 @@ def __str__(self) -> str:
 class UnigramVocabulary(Vocabulary):
   """Vocabulary that does table-lookup of unigrams."""
 
-  def __init__(self, unigrams: Sequence[str]):
+  def __init__(self, unigrams: Sequence[str], split_on_space: bool = False):
     """UnigramVocabulary constructor.
 
     Args:
       unigrams: the collection of in-vocabulary tokens. This collection should
         not include PAD or UNK, which are automatically assigned ids and
         managed as possible decode tokens.
+      split_on_space: if True, encode/decode split/join with the space
+        character. Otherwise, follows legacy behavior: encode (and encode_tf)
+        treats the input as a single token, decode joins with the space
+        character, and decode_tf decodes only the first token.
     """
 
     super().__init__()
@@ -237,19 +241,33 @@ def __init__(self, unigrams: Sequence[str]):
         initializer, num_oov_buckets=1
     )
     self._unigram_by_id_tf = tf.constant(self._unigram_by_id)
+    self._split_on_space = split_on_space
 
   def _encode(self, s: str) -> Sequence[int]:
-    return [self._id_by_unigram.get(s, self.unk_id)]
+    if self._split_on_space:
+      return [
+          self._id_by_unigram.get(unigram, self.unk_id)
+          for unigram in s.split(" ")
+      ]
+    else:
+      return [self._id_by_unigram.get(s, self.unk_id)]
 
   def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
-    tf_ids = self._id_by_unigram_tf.lookup(s)
-    return tf.expand_dims(tf.dtypes.cast(tf_ids, tf.int32), -1)
+    if self._split_on_space:
+      tf_ids = self._id_by_unigram_tf.lookup(tf.strings.split(s, " "))
+      return tf.dtypes.cast(tf_ids, tf.int32)
+    else:
+      tf_ids = self._id_by_unigram_tf.lookup(s)
+      return tf.expand_dims(tf.dtypes.cast(tf_ids, tf.int32), -1)
 
   def _decode(self, ids: Sequence[int]) -> str:
     return " ".join(self._unigram_by_id[id] for id in ids)
 
   def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
-    return self._unigram_by_id_tf[ids[0]]
+    if self._split_on_space:
+      return tf.strings.join(tf.gather(self._unigram_by_id_tf, ids), " ")
+    else:
+      return self._unigram_by_id_tf[ids[0]]
 
   @property
   def _base_vocab_size(self):
diff --git a/seqio/vocabularies_test.py b/seqio/vocabularies_test.py
index de513cad..ec0c84f7 100644
--- a/seqio/vocabularies_test.py
+++ b/seqio/vocabularies_test.py
@@ -201,17 +201,22 @@ def test_not_equal(self):
 
 
-class UnigramVocabularyTest(absltest.TestCase):
+class UnigramVocabularyTest(parameterized.TestCase):
 
-  def test_encode_converts_unigrams_to_ints_correctly(self):
+  @parameterized.parameters((True,), (False,))
+  def test_encode_converts_unigrams_to_ints_correctly(self, split_on_space):
     unigrams = ["this", "that", "is", "not", "a", "the", "test", "ball"]
-    vocabulary = vocabularies.UnigramVocabulary(unigrams)
+    vocabulary = vocabularies.UnigramVocabulary(unigrams, split_on_space)
     self.assertEqual(vocabulary.unk_id, 9)
     with self.subTest(name="pure_python"):
       # Note that id 0 is reserved for padding.
       self.assertEqual(vocabulary.encode("that"), [2])
       self.assertEqual(vocabulary.encode("not"), [4])
       self.assertEqual(vocabulary.encode("apple"), [vocabulary.unk_id])
+      if split_on_space:
+        self.assertEqual(vocabulary.encode("not that"), [4, 2])
+      else:
+        self.assertEqual(vocabulary.encode("not that"), [vocabulary.unk_id])
     with self.subTest(name="tensorflow"):
       # Note that id 0 is reserved for padding.
       # Note that this test must pass under both TF1 and TF2, but the default
@@ -229,14 +234,25 @@ def test_encode_converts_unigrams_to_ints_correctly(self):
           vocabulary.encode_tf(tf.constant("apple")).numpy(),
           [vocabulary.unk_id],
       )
+      if split_on_space:
+        np.testing.assert_array_equal(
+            vocabulary.encode_tf(tf.constant("not that")).numpy(), [4, 2]
+        )
+      else:
+        np.testing.assert_array_equal(
+            vocabulary.encode_tf(tf.constant("not that")).numpy(),
+            [vocabulary.unk_id],
+        )
 
-  def test_decode_converts_ints_to_unigrams_correctly(self):
+  @parameterized.parameters((True,), (False,))
+  def test_decode_converts_ints_to_unigrams_correctly(self, split_on_space):
     unigrams = ["this", "that", "is", "not", "a", "the", "test", "ball"]
-    vocabulary = vocabularies.UnigramVocabulary(unigrams)
+    vocabulary = vocabularies.UnigramVocabulary(unigrams, split_on_space)
     with self.subTest(name="pure_python"):
       self.assertEqual(vocabulary.decode([1]), "this")
       self.assertEqual(vocabulary.decode([3]), "is")
       self.assertEqual(vocabulary.decode([vocabulary.unk_id]), "UNK")
+      self.assertEqual(vocabulary.decode([1, 3]), "this is")
     with self.subTest(name="tensorflow"):
       # Note that this test must pass under both TF1 and TF2, but the default
       # behavior of TF1 == among tensors is to compare object references, not
@@ -248,6 +264,14 @@ def test_decode_converts_ints_to_unigrams_correctly(self):
       self.assertEqual(
           vocabulary.decode_tf(tf.constant([vocabulary.unk_id])).numpy(), b"UNK"
      )
+      if split_on_space:
+        self.assertEqual(
+            vocabulary.decode_tf(tf.constant([4, 2])).numpy(), b"not that"
+        )
+      else:
+        self.assertEqual(
+            vocabulary.decode_tf(tf.constant([4, 2])).numpy(), b"not"
+        )
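
Not part of the patch: a minimal usage sketch of the new `split_on_space` flag, assuming this change is applied to an installed `seqio` checkout. The id values mirror the expectations in the updated tests (PAD=0, the eight unigrams take ids 1-8, UNK=9); everything else is illustrative only.

```python
# Sketch (not part of the diff): UnigramVocabulary with and without the new
# split_on_space flag, using the same unigram list as the tests above.
from seqio import vocabularies

unigrams = ["this", "that", "is", "not", "a", "the", "test", "ball"]

# Legacy behavior (default): the whole input string is looked up as one token,
# so a multi-word string maps to UNK.
legacy = vocabularies.UnigramVocabulary(unigrams)
assert legacy.encode("not") == [4]
assert legacy.encode("not that") == [legacy.unk_id]  # unk_id == 9

# New behavior: encode splits the input on spaces; decode joins tokens with
# spaces, so round-tripping a multi-word string works.
split = vocabularies.UnigramVocabulary(unigrams, split_on_space=True)
assert split.encode("not that") == [4, 2]
assert split.decode([4, 2]) == "not that"
```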