data.py
from typing import Union, Dict
from pathlib import Path
from tape.tokenizers import TAPETokenizer
from tape.datasets import pad_sequences
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import torch


class AMPDataset(Dataset):
    """Peptide dataset that tokenizes sequences and pads batches for AMP classification."""

    def __init__(
        self,
        data_file: Union[str, Path, pd.DataFrame],
        task_label: Union[str, list] = "AMP",
        max_pep_len: int = 180,
        tokenizer: Union[str, TAPETokenizer] = 'iupac',
    ):
        # Accept either an in-memory DataFrame or a path to a CSV file
        if isinstance(data_file, pd.DataFrame):
            data = data_file
        else:
            data = pd.read_csv(data_file)

        # Accept either a vocabulary name or an already-constructed tokenizer
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        # Truncate peptide sequences to the maximum allowed length
        sequences = data['Sequence']
        sequences = sequences.apply(lambda x: x[:max_pep_len])

        if task_label == 'AMP':
            # Single-task case: binary AMP label
            labels = data['Label']
        else:
            # Multi-label case: select the label columns and cast to float for BCE loss
            labels = data.loc[:, task_label].astype('float')

        self.sequences = sequences
        self.targets = labels.to_numpy()

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, index: int):
        # Positional indexing so a non-default DataFrame index does not break lookup
        seq = self.sequences.iloc[index]
        token_ids = self.tokenizer.encode(seq)
        input_mask = np.ones_like(token_ids)
        item = {
            'input_ids': token_ids,
            'input_mask': input_mask,
            'target': self.targets[index]
        }
        return item

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        # Turn a list of per-sample dicts into a dict of lists
        elem = batch[0]
        batch = {key: [d[key] for d in batch] for key in elem}
        # Pad token ids and masks with 0 to the longest sequence in the batch
        input_ids = torch.from_numpy(pad_sequences(batch['input_ids'], 0))
        input_mask = torch.from_numpy(pad_sequences(batch['input_mask'], 0))
        targets = torch.tensor(batch['target'])
        item = {
            'input_ids': input_ids,
            'input_mask': input_mask,
            'targets': targets
        }
        return item
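

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it assumes a CSV
    # file 'train.csv' with 'Sequence' and 'Label' columns; the file name and
    # batch size are illustrative placeholders, not values from the repo.
    from torch.utils.data import DataLoader

    dataset = AMPDataset('train.csv', task_label='AMP')
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=dataset.collate_fn,  # pads variable-length peptides per batch
    )
    for batch in loader:
        # batch['input_ids'] / batch['input_mask']: (B, L) tensors padded with 0
        # batch['targets']: per-sample labels
        print(batch['input_ids'].shape, batch['targets'].shape)
        break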