-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
50 lines (43 loc) · 1.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python3
from __future__ import division # no need for python3, but just in case used w/ python2
import sys
import time
from svector import svector
def read_from(textfile):
    """Yield ``(label, words)`` pairs parsed from a tab-separated text file.

    Each line is expected to look like ``<label>\\t<text>``.  The label is
    mapped to ``1`` when it is exactly ``"+"`` and to ``-1`` otherwise; the
    text is split on whitespace into a list of tokens.

    Args:
        textfile: path of the file to read.

    Yields:
        Tuples ``(label, words)`` with ``label`` in ``{1, -1}`` and ``words``
        a list of strings.

    Raises:
        ValueError: if a line does not contain exactly one tab separator.
    """
    # The context manager closes the file once the generator is exhausted
    # (or discarded) instead of leaving the handle to the garbage collector.
    with open(textfile) as handle:
        for line in handle:
            label, words = line.strip().split("\t")
            yield (1 if label == "+" else -1, words.split())
def make_vector(words):
    """Build a bag-of-words feature vector for one tokenized example.

    The key ``0`` is set to 1 as a bias feature; every token in ``words``
    then increments the entry stored under that token.

    Args:
        words: iterable of tokens.

    Returns:
        A new ``svector`` holding the bias entry and the token counts.
    """
    features = svector()
    features[0] = 1  # bias term, present in every example
    for token in words:
        # relies on svector treating a missing key as 0 -- presumably a
        # defaultdict-like sparse vector; confirm against the svector module
        features[token] += 1
    return features
def test(devfile, model):
    """Return the fraction of examples in ``devfile`` that ``model`` gets wrong.

    An example counts as an error when ``label * model.dot(features) <= 0``,
    i.e. a score of exactly zero is treated as a mistake.

    Args:
        devfile: path of a labelled file readable by ``read_from``.
        model: weight vector exposing ``dot``.

    Returns:
        The error rate as a float between 0 and 1.

    Raises:
        ValueError: if ``devfile`` contains no examples.
    """
    total = 0
    errors = 0
    for label, words in read_from(devfile):
        total += 1
        if label * model.dot(make_vector(words)) <= 0:
            errors += 1
    if not total:
        # Fail with a clear message instead of a NameError on the loop index.
        raise ValueError("no examples found in %s" % devfile)
    return errors / total
def train(trainfile, devfile, epochs=2):
    """Train an averaged perceptron and print the dev error after each epoch.

    The raw perceptron weights (``model``) are the only weights used for
    making updates.  The averaged weights are derived from them as
    ``c * model - wa`` (an unnormalized sum of the weight vectors after each
    example, which has the same sign behaviour as the true average) and are
    used only for evaluation on the dev set.

    Args:
        trainfile: path of the labelled training file.
        devfile: path of the labelled dev file.
        epochs: number of passes over the training file.

    Returns:
        None.  Progress and the best dev error are printed to stdout.
    """
    start = time.time()
    best_err = 1.
    model = svector()  # raw perceptron weights, updated on every mistake
    wa = svector()     # sum of updates, each weighted by the step it happened at
    c = 0              # number of training examples processed so far
    for it in range(1, epochs + 1):
        updates = 0
        i = 0  # stays 0 if the training file is empty
        for i, (label, words) in enumerate(read_from(trainfile), 1):  # label is +1 or -1
            sent = make_vector(words)
            if label * (model.dot(sent)) <= 0:
                updates += 1
                model += label * sent
                wa += c * label * sent
            c += 1
        # Keep the averaged weights in a separate vector: overwriting ``model``
        # with them would corrupt the weights used for subsequent updates.
        avg_model = c * model - wa
        dev_err = test(devfile, avg_model)
        best_err = min(best_err, dev_err)
        update_rate = updates / i if i else 0.
        print("epoch %d, update %.1f%%, dev %.1f%%" % (it, update_rate * 100, dev_err * 100))
    print("best dev err %.1f%%, |w|=%d, time: %.1f secs" % (best_err * 100, len(model), time.time() - start))
if __name__ == "__main__":
    # Usage: train.py <trainfile> <devfile>; runs a single training epoch.
    train_path = sys.argv[1]
    dev_path = sys.argv[2]
    train(train_path, dev_path, epochs=1)