diff --git a/protein_lm/tokenizer/tokenizer.py b/protein_lm/tokenizer/tokenizer.py
index ee60f0c..300e317 100644
--- a/protein_lm/tokenizer/tokenizer.py
+++ b/protein_lm/tokenizer/tokenizer.py
@@ -51,6 +51,8 @@ def batch_encode(
     output = []
     if max_sequence_length is None and return_tensors:
         max_sequence_length = max([len(sequence) for sequence in sequences])
+        if add_special_tokens:
+            max_sequence_length += 2
     if max_sequence_length is not None:
         sequences = [
             sequence[:(max_sequence_length - 2) if add_special_tokens else max_sequence_length]