Skip to content

Commit

Permalink
Fix overflow bug in ByteVocabulary._encode_tf
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570396249
  • Loading branch information
SeqIO Team authored and SeqIO committed Oct 3, 2023
1 parent 29c70c0 commit bae5d7e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions seqio/vocabularies.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,10 @@ def _encode_tf(self, s):
Returns:
a 1d tf.Tensor with dtype tf.int32
"""
tf_ids = tf.io.decode_raw(s, tf.uint8) + self._num_special_tokens
return tf.dtypes.cast(tf_ids, tf.int32)
return (
tf.dtypes.cast(tf.io.decode_raw(s, tf.uint8), tf.int32)
+ self._num_special_tokens
)

def _decode_tf(self, ids):
"""Decode in TensorFlow.
Expand Down

0 comments on commit bae5d7e

Please sign in to comment.