forked from blei-lab/deep-exponential-families
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdef_model.hpp
158 lines (127 loc) · 3.74 KB
/
def_model.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#pragma once
#include <gsl/gsl_rng.h>
#include "utils.hpp"
#include "serialization.hpp"
#include "random.hpp"
#include "def.hpp"
#include "def_layer.hpp"
#include "def_y_layer.hpp"
#include "link_function.hpp"
#include "def_data.hpp"
class DEFModel {
private:
pt::ptree ptree;
shared_ptr<DEFData> def_data;
vector<GSLRandom*> vec_rng;
int iteration;
int layers;
int batch_st;
// train | test
string model_type;
int n_examples, n_dim_y, n_samples, n_dim_z_1;
shared_ptr<ofstream> log_file;
string data_file;
bool exp_fam_mode;
DEF def;
shared_ptr<DEFYLayer> y_layer;
// for exponential distribution. q_w_obs - q_w_obs_b
shared_ptr<DEFPriorLayer> w_obs_layer, w_obs_layer_b;
shared_ptr<InferenceFactorizedLayer> q_w_obs_layer, q_w_obs_layer_b;
struct TrainStats {
int iteration;
shared_ptr<PredictionStats> prediction_stats;
arma::vec lp_y;
BBVIStats w_obs_layer;
arma::vec lq_w_obs;
arma::vec lp_w_obs;
DEF::TrainStats def_stats;
TrainStats(int iteration, int layers, int samples) : iteration(iteration),
lp_y(samples), lq_w_obs(samples), lp_w_obs(samples),
def_stats(iteration, layers, samples) {}
void fill_for_print(const string& model_type) const;
mutable vector<const BBVIStats*> w_for_print;
mutable vector<const arma::vec*> lp_w_for_print;
mutable vector<const arma::vec*> lq_w_for_print;
};
void print_stats(const TrainStats& stats);
void log_stats(const TrainStats& stats, ofstream& of);
void log_stats_header(ofstream& of, const vector<string>& prediction_header);
public:
void set_full(bool full) const { def.set_full(full); }
DEFModel() : data_file("") {}
DEFModel(const string& data_file) : data_file(data_file) {}
DEFModel(const pt::ptree& ptree)
: ptree(ptree), data_file("") {
init();
}
void init();
~DEFModel() {
for (GSLRandom* r : vec_rng) {
delete r;
}
}
int get_iteration() const {
return iteration;
}
void train_model();
TrainStats train_batch(const ExampleIds& example_ids);
TrainStats compute_log_likelihood();
void copy_iteration(const DEFModel& other) {
iteration = other.iteration-1;
}
void copy_w_params(const DEFModel& other) {
q_w_obs_layer->copy_params(&*other.q_w_obs_layer);
if (other.q_w_obs_layer_b == NULL)
q_w_obs_layer_b = NULL;
else
q_w_obs_layer_b->copy_params(&*other.q_w_obs_layer_b);
def.copy_w_params(other.def);
}
friend class boost::serialization::access;
BOOST_SERIALIZATION_SPLIT_MEMBER();
template<class Archive>
void save(Archive& ar, const unsigned int) const {
ar & ptree;
ar & iteration;
ar & batch_st;
for(int i=0; i<n_samples; ++i) {
ar & *vec_rng[i];
}
ar & def;
InferenceFactorizedLayer* lp = q_w_obs_layer.get();
ar & lp;
if (q_w_obs_layer_b != NULL) {
InferenceFactorizedLayer* lp_b = q_w_obs_layer_b.get();
ar & lp_b;
}
}
template<class Archive>
void load(Archive& ar, const unsigned int) {
ar & ptree;
init();
ar & iteration;
ar & batch_st;
for(int i=0; i<n_samples; ++i) {
ar & *vec_rng[i];
}
ar & def;
InferenceFactorizedLayer* lp;
ar & lp;
q_w_obs_layer.reset(lp);
if (q_w_obs_layer_b != NULL) {
InferenceFactorizedLayer* lp_b;
ar & lp_b;
q_w_obs_layer_b.reset(lp_b);
}
}
void load_part(shared_ptr<DEFModel> part_model, int k) {
assert(k >= 1);
w_obs_layer = part_model->w_obs_layer;
w_obs_layer_b = part_model->w_obs_layer_b;
q_w_obs_layer = part_model->q_w_obs_layer;
q_w_obs_layer_b = part_model->q_w_obs_layer_b;
def.load_part(part_model->def, k);
}
void save_params(const string& fname) const;
void load_params(const string& fname);
};