diff --git a/alpha_automl/wrapper_primitives/huggingface_text.py b/alpha_automl/wrapper_primitives/huggingface_text.py
index 383d304f..c6e0be41 100644
--- a/alpha_automl/wrapper_primitives/huggingface_text.py
+++ b/alpha_automl/wrapper_primitives/huggingface_text.py
@@ -13,9 +13,10 @@ class HuggingfaceTextTransformer(BasePrimitive):
-    def __init__(self, name, tokenizer=None):
+    def __init__(self, name, tokenizer=None, max_length=512):
         self.name = name
         self.tokenizer = tokenizer if tokenizer else name
+        self.max_length = max_length
 
     def fit(self, X, y=None):
         return self
@@ -44,10 +45,8 @@ def transform(self, texts):
             batch_texts = list_texts[start: start + batch_size]
             # batch_texts = [' '.join(line.split()) if str(line)!='nan' else '' for line in batch_texts]
-            ids = tokenizer(batch_texts, padding=True, return_tensors="pt")
-            ids['input_ids'] = ids['input_ids'][:, :512]
-            ids['token_type_ids'] = ids['token_type_ids'][:, :512]
-            ids['attention_mask'] = ids['attention_mask'][:, :512]
+            ids = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length)
+
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
             model.to(device)
             ids = ids.to(device)
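
Note on the tokenizer call: passing max_length alone does not truncate in the HuggingFace tokenizers API; truncation must be enabled explicitly, so truncation=True is added to the new call above to preserve the behavior of the removed manual slicing. A minimal standalone sketch of this behavior (independent of alpha_automl; assumes transformers and torch are installed, and uses bert-base-uncased purely for illustration):

    # Sketch: max_length only takes effect when truncation is enabled.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    texts = ["word " * 1000, "short text"]  # first text exceeds 512 tokens

    # With truncation=True, the long example is cut to max_length and the
    # short one is padded up to the longest sequence in the batch.
    ids = tokenizer(texts, padding=True, truncation=True,
                    return_tensors="pt", max_length=512)
    print(ids["input_ids"].shape)  # torch.Size([2, 512])

Letting the tokenizer truncate also keeps input_ids, attention_mask, and (when present) token_type_ids consistent, which is what the four removed lines did by hand; the manual version would also have raised a KeyError for tokenizers that do not emit token_type_ids at all (DistilBERT's, for example).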