
Commit 54b7f14
Merge pull request #93 from VIDA-NYU/fix_huggingface_wrapper_not_support_distilbert

fix huggingface wrapper not support distilbert
EdenWuyifan authored Feb 9, 2024
2 parents 744fd71 + 06efc6b commit 54b7f14
Showing 1 changed file with 4 additions and 5 deletions.
alpha_automl/wrapper_primitives/huggingface_text.py
@@ -13,9 +13,10 @@

 class HuggingfaceTextTransformer(BasePrimitive):

-    def __init__(self, name, tokenizer=None):
+    def __init__(self, name, tokenizer=None, max_length=512):
         self.name = name
         self.tokenizer = tokenizer if tokenizer else name
+        self.max_length = max_length

     def fit(self, X, y=None):
         return self
@@ -44,10 +45,8 @@ def transform(self, texts):
             batch_texts = list_texts[start: start + batch_size]

             # batch_texts = [' '.join(line.split()) if str(line)!='nan' else '' for line in batch_texts]
-            ids = tokenizer(batch_texts, padding=True, return_tensors="pt")
-            ids['input_ids'] = ids['input_ids'][:, :512]
-            ids['token_type_ids'] = ids['token_type_ids'][:, :512]
-            ids['attention_mask'] = ids['attention_mask'][:, :512]
+            ids = tokenizer(batch_texts, padding=True, return_tensors="pt", max_length=self.max_length)

             device = 'cuda' if torch.cuda.is_available() else 'cpu'
             model.to(device)
             ids = ids.to(device)
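The lines removed in the second hunk assumed every tokenizer returns a
token_type_ids tensor. DistilBERT's tokenizer does not, so indexing that key
failed for DistilBERT models, which is the incompatibility the commit title
refers to; the new code leaves key handling entirely to the tokenizer. A
minimal sketch (not from the commit; model names are illustrative) of the
difference in tokenizer outputs:

# Sketch: compare tokenizer outputs for a BERT and a DistilBERT checkpoint.
from transformers import AutoTokenizer

bert_tok = AutoTokenizer.from_pretrained('bert-base-uncased')
distil_tok = AutoTokenizer.from_pretrained('distilbert-base-uncased')

print(bert_tok(['hello world'], return_tensors='pt').keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

print(distil_tok(['hello world'], return_tensors='pt').keys())
# dict_keys(['input_ids', 'attention_mask'])

# The old ids['token_type_ids'][:, :512] slice therefore raised a KeyError
# whenever the wrapped model was a DistilBERT variant.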

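With the new max_length parameter, callers can tune the sequence cap per
instance instead of relying on the hardcoded 512. A hypothetical usage sketch
(model name and texts are illustrative):

# Hypothetical usage of the updated wrapper; model name and texts are illustrative.
from alpha_automl.wrapper_primitives.huggingface_text import HuggingfaceTextTransformer

texts = ['first document', 'a much longer second document']
transformer = HuggingfaceTextTransformer('distilbert-base-uncased', max_length=256)
embeddings = transformer.fit(texts).transform(texts)  # fit() returns self

One caveat worth flagging, not addressed by this commit: in recent versions of
transformers, passing max_length without truncation=True only truncates via a
backward-compatibility fallback that emits a warning, so callers who depend on
hard truncation may want to pass truncation=True in the tokenizer call as well.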