generate.py
import argparse
import json

import torch

from source.constants import BOS
from source.datasets import Vocabulary
from source.model import DefinitionModelingModel
from source.pipeline import generate
from source.utils import prepare_ada_vectors_from_python, prepare_w2v_vectors

parser = argparse.ArgumentParser(description='Script to generate using model')
parser.add_argument(
    "--params", type=str, required=True,
    help="path to saved model params"
)
parser.add_argument(
    "--ckpt", type=str, required=True,
    help="path to saved model weights"
)
parser.add_argument(
    "--tau", type=float, required=True,
    help="temperature to use in sampling"
)
parser.add_argument(
    "--n", type=int, required=True,
    help="number of samples to generate"
)
parser.add_argument(
    "--length", type=int, required=True,
    help="maximum length of generated samples"
)
parser.add_argument(
    "--prefix", type=str, required=False,
    help="prefix to read until generation starts"
)
parser.add_argument(
    "--wordlist", type=str, required=False,
    help="path to word list with words and contexts"
)
parser.add_argument(
    "--w2v_binary_path", type=str, required=False,
    help="path to binary w2v file"
)
parser.add_argument(
    "--ada_binary_path", type=str, required=False,
    help="path to binary ada file"
)
parser.add_argument(
    "--prep_ada_path", type=str, required=False,
    help="path to prep_ada.jl script"
)
args = parser.parse_args()
with open(args.params, "r") as infile:
    model_params = json.load(infile)
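# The params file is the JSON dumped at training time; the code below reads
# "voc", "pretrain", "use_ch" and "use_seed" from it (via model.params), plus
# "context_voc" / "ch_voc" when attention or char-level features are enabled.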
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DefinitionModelingModel(model_params).to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine
model.load_state_dict(torch.load(args.ckpt, map_location=device)["state_dict"])
voc = Vocabulary()
voc.load(model_params["voc"])
to_input = {
    "model": model,
    "voc": voc,
    "tau": args.tau,
    "n": args.n,
    "length": args.length,
    "device": device,
}
# Pretrained language-model mode: condition only on the optional --prefix.
if model.params["pretrain"]:
    to_input["prefix"] = args.prefix
    print(generate(**to_input))
else:
    # Definition mode: generate one definition per (word, context) line in --wordlist.
    assert args.wordlist is not None, ("to generate definitions in --pretrain "
                                       "False mode --wordlist is required")
    with open(args.wordlist, "r") as infile:
        data = infile.readlines()
    # Precompute conditioning vectors for the whole word list.
    if model.is_w2v:
        assert args.w2v_binary_path is not None, ("model.is_w2v True => "
                                                  "--w2v_binary_path is "
                                                  "required")
        input_vecs = torch.from_numpy(
            prepare_w2v_vectors(args.wordlist, args.w2v_binary_path)
        )
    if model.is_ada:
        assert args.ada_binary_path is not None, ("model.is_ada True => "
                                                  "--ada_binary_path is "
                                                  "required")
        assert args.prep_ada_path is not None, ("model.is_ada True => "
                                                "--prep_ada_path is "
                                                "required")
        input_vecs = torch.from_numpy(
            prepare_ada_vectors_from_python(
                args.wordlist,
                args.prep_ada_path,
                args.ada_binary_path
            )
        )
    if model.is_attn:
        context_voc = Vocabulary()
        context_voc.load(model.params["context_voc"])
        to_input["context_voc"] = context_voc
    if model.params["use_ch"]:
        ch_voc = Vocabulary()
        ch_voc.load(model.params["ch_voc"])
        to_input["ch_voc"] = ch_voc
    for i in range(len(data)):
        word, context = data[i].split('\t')
        context = context.strip()
        if model.is_w2v or model.is_ada:
            to_input["input"] = input_vecs[i]
        if model.is_attn:
            to_input["word"] = word
            to_input["context"] = context
        if model.params["use_ch"]:
            to_input["CH_word"] = word
        # Seed generation with the word itself or with the BOS token.
        if model.params["use_seed"]:
            to_input["prefix"] = word
        else:
            to_input["prefix"] = BOS
        print("Word: {0}".format(word))
        print("Context: {0}".format(context))
        print(generate(**to_input))
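# Example invocation (a sketch only; these file names are assumptions, not files
# shipped with the repo):
#   python generate.py --params params.json --ckpt model.ckpt \
#       --tau 0.5 --n 1 --length 30 --wordlist wordlist.txt
# Each --wordlist line is split on a single tab into a word and its context,
# e.g. (hypothetical): bank<TAB>they moored the boat on the bank of the river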