-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmodel.go
179 lines (148 loc) · 3.34 KB
/
model.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package randtxt
import (
"errors"
"math/rand"
"strings"
"github.com/pboyd/markov"
)
// Model steps through the chain word by word and reports probabilities.
//
// Model is a lower-level interface. Generator is the recommended way to
// generate text.
type Model struct {
	// chain is the Markov chain the model walks.
	chain markov.Chain
	// current is the raw value most recently chosen by Step. NewModel leaves
	// it empty, so it has no value until Step has been called.
	current string
	// past holds the words of the phrase used to look up the next links. It
	// starts as the seed split on spaces, is shifted left by one word on each
	// Step, and is replaced wholesale when the model reseeds.
	past []string
}
// NewModel builds a Model that walks chain starting at seed. A blank seed
// requests a randomly chosen starting point instead.
//
// An error is returned when a random seed cannot be obtained or when
// chain.Find fails for the seed.
func NewModel(chain markov.Chain, seed string) (*Model, error) {
	start := seed
	if start == "" {
		picked, err := randomSeed(chain)
		if err != nil {
			return nil, err
		}
		start = picked
	}

	// Reject a seed that the chain cannot find.
	if _, err := chain.Find(start); err != nil {
		return nil, err
	}

	m := &Model{
		chain: chain,
		past:  strings.Split(start, " "),
	}
	return m, nil
}
// randomSeed draws random entries from chain until it finds one usable as a
// starting phrase, and returns that entry.
//
// An entry is accepted only when its number of space-separated words equals
// the size reported by inspectChain; any other entry is discarded and a new
// one is drawn. The loop ends only on a match or on an error from
// markov.Random. Errors from inspectChain and markov.Random are returned to
// the caller.
func randomSeed(chain markov.Chain) (string, error) {
	size, err := inspectChain(chain)
	if err != nil {
		// Propagate the failure. Returning a nil error here would hand the
		// caller an empty seed that looks valid.
		return "", err
	}

	for {
		raw, err := markov.Random(chain)
		if err != nil {
			return "", err
		}

		// The assertion is unchecked, so a non-string chain value panics here.
		phrase := raw.(string)
		if len(strings.Split(phrase, " ")) == size {
			return phrase, nil
		}
	}
}
// Current returns the word and POS tag that the model is currently at.
func (m *Model) Current() Tag {
	tag := parseTag(m.current)
	return tag
}
// NextTags lists the tags that may follow the current phrase, each paired
// with the probability reported by its link. One entry is produced per link,
// in link order. Any error from looking up the links or fetching a linked
// value is returned.
func (m *Model) NextTags() ([]TagProbability, error) {
	links, err := m.nextLinks()
	if err != nil {
		return nil, err
	}

	result := make([]TagProbability, 0, len(links))
	for i := range links {
		value, err := m.chain.Get(links[i].ID)
		if err != nil {
			return nil, err
		}
		result = append(result, TagProbability{
			raw:         value.(string),
			Probability: links[i].Probability,
		})
	}
	return result, nil
}
// Step advances the model by one word: it picks the next value, drops the
// oldest word from the remembered phrase, appends the new one, and records it
// as the current value.
func (m *Model) Step() error {
	word, err := m.pickNext()
	if err != nil {
		return err
	}

	// Measure the phrase only after pickNext, which may have reseeded it.
	// Slide every word one slot to the left so the last slot is free.
	last := len(m.past) - 1
	copy(m.past, m.past[1:])
	m.past[last] = word

	m.current = word
	return nil
}
// pickNext randomly selects the value that follows the current phrase,
// weighting each candidate link by its Probability, and returns it.
//
// Errors from nextLinks and from fetching the chosen value are returned
// unchanged. An error is also returned if there are no links to choose from.
func (m *Model) pickNext() (string, error) {
	links, err := m.nextLinks()
	if err != nil {
		return "", err
	}
	if len(links) == 0 {
		return "", errors.New("failed")
	}

	// Walk the cumulative distribution until it passes the random threshold.
	// Default to the last link: summing float probabilities can leave the
	// total marginally below 1.0 (assuming they are meant to sum to 1), and a
	// threshold in that gap would otherwise select nothing.
	chosen := links[len(links)-1]
	threshold := rand.Float64()
	var cumulative float64
	for _, link := range links {
		cumulative += link.Probability
		if cumulative > threshold {
			chosen = link
			break
		}
	}

	raw, err := m.chain.Get(chosen.ID)
	if err != nil {
		return "", err
	}
	// The assertion is unchecked, so a non-string chain value panics here.
	return raw.(string), nil
}
// nextLinks returns the links leading out of the current phrase. When the
// phrase has no outgoing links, the model is reseeded at a random point and
// the lookup is repeated until a phrase with at least one link is found or an
// error occurs.
func (m *Model) nextLinks() ([]markov.Link, error) {
	for {
		phrase := strings.Join(m.past, " ")
		id, err := m.chain.Find(phrase)
		if err != nil {
			return nil, err
		}

		links, err := m.chain.Links(id)
		if err != nil {
			return nil, err
		}
		if len(links) > 0 {
			return links, nil
		}

		// Dead end: a chain that finishes on a unique phrase has nothing
		// after it. Restart from a random point, which is not ideal because
		// it may land mid-sentence.
		if err := m.reseed(); err != nil {
			return nil, err
		}
	}
}
// reseed replaces the remembered phrase with the words of a freshly drawn
// random seed. On failure the phrase is left untouched and the error is
// returned.
func (m *Model) reseed() error {
	phrase, err := randomSeed(m.chain)
	if err == nil {
		m.past = strings.Split(phrase, " ")
	}
	return err
}
// TagProbability contains a Tag and the probability that it will be used next.
// Returned in a slice from Model.NextTags.
type TagProbability struct {
	// raw is the unparsed string value fetched from the chain; the Tag
	// method parses it on demand.
	raw string
	// Probability is the probability of the link that leads to this tag.
	Probability float64
}
// Tag parses the tag in its raw form and returns it.
func (tp *TagProbability) Tag() Tag {
	parsed := parseTag(tp.raw)
	return parsed
}