-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathowlqn.go
93 lines (78 loc) · 2.35 KB
/
owlqn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package owlqn
import (
"container/list"
"fmt"
"math"
)
// Tuning constants for TerminationCriterion.
const (
	// numItersToAvg is the minimum number of previously recorded objective
	// values required before GetValue reports a relative improvement.
	numItersToAvg int = 5
	// maxPrevVals caps the sliding window of previously recorded objective
	// values; once it is full, the oldest value is dropped every iteration.
	maxPrevVals int = 10
)

// TerminationCriterion is a relative mean-improvement stopping rule. It keeps
// a sliding window of recent objective values and reports the average
// per-iteration decrease relative to the current objective value.
//
// The zero value is usable; NewTerminationCriterion is also provided.
type TerminationCriterion struct {
	// prevVals holds float32 objective values from earlier iterations,
	// oldest at the front.
	prevVals *list.List
	// quiet suppresses the progress messages printed by GetValue.
	quiet bool
}

// GetValue records the current objective value of state and returns the
// relative average improvement over the recorded window:
//
//	(oldest - current) / steps / |current|
//
// where steps is the number of iterations between the oldest recorded value
// and the current one. Until numItersToAvg values have been recorded it
// returns math.MaxFloat32, which never satisfies a convergence tolerance.
// A negative result means the objective increased across the window.
func (termCrit *TerminationCriterion) GetValue(state *OptimizerState) float32 {
	if termCrit.prevVals == nil {
		// Make the zero value usable instead of panicking on a nil list.
		termCrit.prevVals = list.New()
	}
	value := state.GetValue()

	var retval float32 = math.MaxFloat32
	if steps := termCrit.prevVals.Len(); steps >= numItersToAvg {
		prevVal := termCrit.prevVals.Front().Value.(float32)
		// The front value was recorded `steps` iterations ago. Divide before
		// trimming the window: trimming first would divide a full window's
		// improvement by one step too few.
		averageImprovement := (prevVal - value) / float32(steps)

		// No change across the window means converged. Guarding here also
		// avoids 0/0 = NaN when the objective is exactly 0; NaN never
		// compares < tol, so the caller would loop forever.
		var relAvgImpr float32
		if averageImprovement != 0 {
			relAvgImpr = averageImprovement / abs(value)
		}
		if !termCrit.quiet {
			fmt.Printf("relAvgImpr:%f\n", relAvgImpr)
		}
		retval = relAvgImpr
	} else if !termCrit.quiet {
		fmt.Println(" (wait for five iters) ")
	}

	// Keep at most maxPrevVals entries once the current value is appended.
	for termCrit.prevVals.Len() >= maxPrevVals {
		termCrit.prevVals.Remove(termCrit.prevVals.Front())
	}
	termCrit.prevVals.PushBack(value)
	return retval
}

// OWLQN minimizes an L1-regularized objective with the OWL-QN method.
type OWLQN struct {
	// quiet suppresses progress output.
	quiet    bool
	termCrit *TerminationCriterion
}

// NewTerminationCriterion returns a TerminationCriterion with an empty
// history window.
func NewTerminationCriterion() *TerminationCriterion {
	return &TerminationCriterion{prevVals: list.New()}
}

// NewOWLQN returns an optimizer that uses the relative mean-improvement
// termination criterion. When quiet is true, both Minimize's own progress
// output and the termination criterion's messages are suppressed.
func NewOWLQN(quiet bool) *OWLQN {
	termCrit := NewTerminationCriterion()
	termCrit.quiet = quiet
	return &OWLQN{
		quiet:    quiet,
		termCrit: termCrit,
	}
}
// Minimize runs OWL-QN on f starting from init and copies the final point
// into result.
//
// l1weight is the L1 regularization weight, tol is the convergence tolerance
// compared against the termination criterion's value, and m is the L-BFGS
// memory parameter. The loop runs until the criterion drops below tol; there
// is no iteration cap, so a tolerance that is never reached loops forever.
func (opt *OWLQN) Minimize(f costfunction, init []float32, result []float32, l1weight float32, tol float32, m int) {
	state := NewOptimizerState(f, m, init, l1weight, opt.quiet)

	// Start every run with an empty termination window. Otherwise values
	// recorded by a previous Minimize call on this optimizer would be compared
	// against this run's values and could stop it (or delay it) spuriously.
	opt.termCrit.prevVals = list.New()

	if !opt.quiet {
		fmt.Printf("Optimizing function of %d variables with OWLQN parameters:\n", state.dim)
		fmt.Printf("l1 regularization weight:%f.\n", l1weight)
		fmt.Printf("L-BFGS memory parameter (m):%d\n", m)
		fmt.Printf("Convergence tolerance:%f\n", tol)
		fmt.Printf("Iter n: new_value\n")
		fmt.Printf("Iter 0: %f\n", state.value)
	}

	for {
		if !opt.quiet {
			fmt.Printf("Iter %d\n", state.iter)
		}
		state.UpdateDir()
		state.BackTrackingLineSearch()

		termCritVal := opt.termCrit.GetValue(state)
		if !opt.quiet {
			fmt.Printf("state.value:%f\n", state.value)
		}
		if termCritVal < tol {
			break
		}
		state.Shift()
	}

	// Shift is skipped on the converging iteration, so newX is presumably the
	// point accepted by the last line search — confirm against OptimizerState.
	DeepCopy(result, state.newX)
}