-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata.py
122 lines (101 loc) · 4.18 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#!/usr/bin/env python3
import os
import torch
def _tokenize(text_path, dictionary_to_update):
    """Convert a whitespace-tokenized text file into a tensor of token ids.

    Each line is split on whitespace and terminated with an ``'<eos>'``
    token. Tokens not yet in ``dictionary_to_update`` are added to it in
    place, receiving consecutive ids starting at
    ``len(dictionary_to_update)``, so several files can share one vocabulary.

    Args:
        text_path: path of a UTF-8 encoded text file.
        dictionary_to_update: dict mapping token -> id; mutated in place.

    Returns:
        A 1-D ``torch.long`` tensor with one id per token, in file order.

    Raises:
        FileNotFoundError: if ``text_path`` does not exist.
    """
    print('Tokenizing {}'.format(text_path))
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(text_path):
        raise FileNotFoundError('No such text file: {}'.format(text_path))

    next_id = len(dictionary_to_update)
    ids = []
    with open(text_path, 'r', encoding="utf8") as f:
        # One pass suffices: a new token gets its id the first time it is
        # seen, and that id is recorded immediately.
        for line in f:
            for token in line.split() + ['<eos>']:
                token_id = dictionary_to_update.get(token)
                if token_id is None:
                    token_id = next_id
                    dictionary_to_update[token] = token_id
                    next_id += 1
                ids.append(token_id)
    return torch.tensor(ids, dtype=torch.long)
class Corpus:
    """Tokenized train/valid/test splits of a text dataset.

    ``train.txt``, ``valid.txt`` and ``test.txt`` are read from ``data_path``
    in that order and stored as id tensors in ``self.train``, ``self.valid``
    and ``self.test``. All three splits share a single vocabulary, so a token
    maps to the same id in every split.
    """

    def __init__(self, data_path):
        # token -> id, grown in place as each split is tokenized.
        # Corpus objects are pickled to disk with torch.save, so renaming
        # these attributes would break previously saved corpus files.
        self._dictionary = {}

        def tokenize_split(file_name):
            return _tokenize(
                text_path=os.path.join(data_path, file_name),
                dictionary_to_update=self._dictionary)

        # Order matters: ids are handed out in first-seen order across splits.
        self.train = tokenize_split('train.txt')
        self.valid = tokenize_split('valid.txt')
        self.test = tokenize_split('test.txt')

    @property
    def vocab_size(self):
        """Number of distinct tokens seen across all splits."""
        return len(self._dictionary)
def _batchify(data_tensor, batch_size):
    """Lay a token stream out as ``batch_size`` parallel rows.

    The stream (typically the 1-D id tensor produced by ``_tokenize``) is cut
    into ``batch_size`` contiguous, equal-length pieces, one per row. Trailing
    tokens that would not fill every row are dropped.

    Args:
        data_tensor: tensor of token ids, split along dimension 0.
        batch_size: number of rows; must be positive.

    Returns:
        A contiguous tensor with ``batch_size`` rows.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``data_tensor``
            has fewer than ``batch_size`` elements along dimension 0 (every
            row would be empty).
    """
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {}'.format(batch_size))
    # Number of consecutive tokens each row receives.
    tokens_per_row = data_tensor.size(0) // batch_size
    if tokens_per_row == 0:
        raise ValueError(
            'cannot batchify {} tokens into {} rows'.format(
                data_tensor.size(0), batch_size))
    # Trim away the tail so the stream divides evenly into rows.
    data_tensor = data_tensor.narrow(0, 0, tokens_per_row * batch_size)
    return data_tensor.view(batch_size, -1).contiguous()
def _load_corpus(corpus_path):
    """Unpickle a corpus object previously written with ``torch.save``."""
    import inspect

    kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle arbitrary classes such as Corpus; opt out explicitly.
        kwargs['weights_only'] = False
    # SECURITY: this unpickles arbitrary Python objects. Only load corpus
    # files written by this code, never files from untrusted sources.
    return torch.load(corpus_path, **kwargs)


def _build_corpus(data_path, env_params):
    """Return the ``Corpus`` for ``data_path``, caching it on disk.

    The tokenized corpus is saved to ``<data_path>/corpus.pt`` so that later
    runs skip tokenization. In distributed mode only rank 0 tokenizes and
    writes the cache file. Every rank then synchronizes once and loads it.

    Args:
        data_path: directory containing train.txt, valid.txt and test.txt.
        env_params: dict with ``'distributed'``; when that is true it must
            also hold ``'rank'``. Synchronization broadcasts a CUDA tensor,
            so distributed mode requires a GPU.

    Returns:
        The corpus object.
    """
    corpus_path = os.path.join(data_path, 'corpus.pt')
    distributed = env_params['distributed']
    # Outside distributed training the single process is the writer.
    is_writer = not distributed or env_params['rank'] == 0

    corpus = None
    if is_writer and not os.path.exists(corpus_path):
        print('Creating a corpus file at {}'.format(corpus_path))
        corpus = Corpus(data_path)
        torch.save(corpus, corpus_path)

    if distributed:
        if not is_writer:
            print('Waiting rank0 to create a corpus file.')
        # Every rank joins this collective exactly once, whatever it sees on
        # disk. If each rank decided from os.path.exists on its own, a slow
        # rank could see the file rank 0 has just written and skip the
        # collective, which hangs the job or mismatches collectives. It could
        # also read a half-written file.
        torch.distributed.broadcast(torch.zeros(1).cuda(), src=0)

    if corpus is None:
        print('Loading an existing corpus file from {}'.format(corpus_path))
        corpus = _load_corpus(corpus_path)
    return corpus
def _get_train_val_test_data(corpus, batch_size):
    """Batchify the three splits of ``corpus``.

    Returns a list ``[train, valid, test]`` in which each split has been
    passed through ``_batchify`` with ``batch_size`` rows.
    """
    splits = (corpus.train, corpus.valid, corpus.test)
    return [_batchify(split, batch_size) for split in splits]
def get_train_val_test_data(data_params, env_params, batch_size, device):
    """Build (or load) the corpus and return batched train/val/test tensors.

    Side effect: writes the vocabulary size into ``data_params['vocab_size']``.

    Args:
        data_params: dict holding at least ``'data_path'``; updated in place.
        env_params: dict with ``'distributed'``. When that is true it must also
            hold ``'rank'`` and ``'world_size'``.
        batch_size: total number of rows across all processes.
        device: device the returned tensors are moved to.

    Returns:
        ``(train_data, val_data, test_data)`` on ``device``. When distributed,
        each holds only this rank's ``batch_size // world_size`` rows.

    Raises:
        ValueError: if distributed and ``batch_size`` is not divisible by
            ``world_size``.
    """
    distributed = env_params['distributed']
    if distributed and batch_size % env_params['world_size'] != 0:
        # Checked up front (not with ``assert``) so it survives ``python -O``
        # and fails before the expensive corpus build.
        raise ValueError(
            'batch_size ({}) must be divisible by world_size ({})'.format(
                batch_size, env_params['world_size']))

    # Forward only data_path. data_params gains a 'vocab_size' key below,
    # and unpacking the whole dict would make a repeated call fail with an
    # unexpected-keyword TypeError.
    corpus = _build_corpus(
        data_path=data_params['data_path'], env_params=env_params)
    data_params['vocab_size'] = corpus.vocab_size

    splits = _get_train_val_test_data(corpus=corpus, batch_size=batch_size)

    if distributed:
        # Give each rank an equal, disjoint block of rows.
        device_batch_size = batch_size // env_params['world_size']
        rank = env_params['rank']
        rows = slice(device_batch_size * rank, device_batch_size * (rank + 1))
        splits = [split[rows] for split in splits]

    train_data, val_data, test_data = (split.to(device) for split in splits)
    return train_data, val_data, test_data