/*
 * RecursiveNN.h
 *
 *  Created on: Mar 18, 2015
 *      Author: mszhang
 */
#ifndef SRC_RecursiveNN_H_
#define SRC_RecursiveNN_H_
#include "tensor.h"
#include "BiLayer.h"
#include "MyLib.h"
#include "Utiltensor.h"
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow::utils;
// Strictly speaking, this class is not needed: BiLayer already provides the same functionality.
template<typename xpu>
class RecursiveNN {
public:
  BiLayer<xpu> _rnn;

public:
  RecursiveNN() {
  }

  // Initialize with square weight matrices: left child, right child and parent all share one dimension.
  inline void initial(int dimension, int seed = 0) {
    _rnn.initial(dimension, dimension, dimension, true, seed, 0);
  }

  // Initialize from existing parameters: left weight, right weight and bias.
  inline void initial(Tensor<xpu, 2, dtype> WL, Tensor<xpu, 2, dtype> WR, Tensor<xpu, 2, dtype> b) {
    _rnn.initial(WL, WR, b, true);
  }

  inline void release() {
    _rnn.release();
  }

  virtual ~RecursiveNN() {
  }

  // Squared L2 norm of all parameter gradients (used for gradient checking/clipping).
  inline dtype squarenormAll() {
    return _rnn.squarenormAll();
  }

  inline void scaleGrad(dtype scale) {
    _rnn.scaleGrad(scale);
  }
public:
  // Forward pass: combine the two child vectors xl and xr into the parent vector y
  // through the underlying BiLayer (y is cleared first, then written).
  inline void ComputeForwardScore(Tensor<xpu, 2, dtype> xl, Tensor<xpu, 2, dtype> xr, Tensor<xpu, 2, dtype> y) {
    y = 0.0;
    _rnn.ComputeForwardScore(xl, xr, y);
  }

  // Backward pass. The gradient buffers lxl and lxr must be allocated by the caller;
  // pass bclear = true to zero them before gradients are accumulated.
  inline void ComputeBackwardLoss(Tensor<xpu, 2, dtype> xl, Tensor<xpu, 2, dtype> xr, Tensor<xpu, 2, dtype> y, Tensor<xpu, 2, dtype> ly,
      Tensor<xpu, 2, dtype> lxl, Tensor<xpu, 2, dtype> lxr, bool bclear = false) {
    if (bclear) {
      lxl = 0.0;
      lxr = 0.0;
    }
    _rnn.ComputeBackwardLoss(xl, xr, y, ly, lxl, lxr);
  }

  // Print a few randomly chosen parameter values (for debugging).
  inline void randomprint(int num) {
    _rnn.randomprint(num);
  }

  // AdaGrad parameter update with L2 regularization.
  inline void updateAdaGrad(dtype regularizationWeight, dtype adaAlpha, dtype adaEps) {
    _rnn.updateAdaGrad(regularizationWeight, adaAlpha, adaEps);
  }

  void writeModel(LStream &outf) {
    _rnn.writeModel(outf);
  }

  void loadModel(LStream &inf) {
    _rnn.loadModel(inf);
  }
};
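
// Usage sketch (an illustrative addition, not part of the original file).
// It shows the expected call sequence for one tree node. It assumes mshadow's
// NewTensor/FreeSpace helpers from "tensor.h" and the CPU device; the
// dimension (50), seed (1234) and AdaGrad hyper-parameters are arbitrary.
//
//   RecursiveNN<cpu> rnn;
//   rnn.initial(50, 1234);                        // 50-dim node vectors
//
//   Tensor<cpu, 2, dtype> xl = NewTensor<cpu, dtype>(Shape2(1, 50), (dtype)0.0);
//   Tensor<cpu, 2, dtype> xr = NewTensor<cpu, dtype>(Shape2(1, 50), (dtype)0.0);
//   Tensor<cpu, 2, dtype> y  = NewTensor<cpu, dtype>(Shape2(1, 50), (dtype)0.0);
//
//   rnn.ComputeForwardScore(xl, xr, y);           // combine children into parent
//
//   // ... compute the loss gradient ly on y, allocate lxl and lxr, then:
//   // rnn.ComputeBackwardLoss(xl, xr, y, ly, lxl, lxr, true);
//   // rnn.updateAdaGrad(1e-8, 0.01, 1e-6);
//
//   FreeSpace(&xl); FreeSpace(&xr); FreeSpace(&y);
//   rnn.release();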
#endif /* SRC_RecursiveNN_H_ */