Skip to content

Commit

Permalink
testing #19 for gpr aep -- not working
Browse files Browse the repository at this point in the history
  • Loading branch information
thangbui committed Jun 12, 2017
1 parent cf7910c commit 68b6886
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
17 changes: 10 additions & 7 deletions examples/gpr_aep_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .datautils import step, spiral
from .context import aep

def run_regression_1D():
def run_regression_1D(nat_param=True):
np.random.seed(42)

print "create dataset ..."
Expand Down Expand Up @@ -37,10 +37,12 @@ def plot(m):
# inference
print "create model and optimize ..."
M = 20
model = aep.SGPR(X, Y, M, lik='Gaussian')
model.optimise(method='L-BFGS-B', alpha=0.1, maxiter=50000)
plot(model)
plt.show()
model = aep.SGPR(X, Y, M, lik='Gaussian', nat_param=nat_param)
model.update_hypers(model.init_hypers(Y))
print model.objective_function(model.get_hypers(), N, 0.1)
model.optimise(method='L-BFGS-B', alpha=0.1, maxiter=1)
# plot(model)
# plt.show()


def run_banana():
Expand Down Expand Up @@ -307,11 +309,12 @@ def run_boston():


if __name__ == '__main__':
# run_regression_1D()
run_regression_1D(True)
run_regression_1D(False)
# run_banana()
# run_step_1D()
# run_spiral()
# run_boston()

# run_regression_1D_stoc()
run_banana_stoc()
# run_banana_stoc()
4 changes: 0 additions & 4 deletions geepee/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,6 @@ def compute_posterior_grad_u(self, dmu, dSu):

return deta1_R, deta2, dKuuinv




def init_hypers(self, x_train=None, key_suffix=''):
"""Summary
Expand Down Expand Up @@ -575,7 +572,6 @@ def init_hypers(self, x_train=None, key_suffix=''):
triu_ind = np.triu_indices(M)
diag_ind = np.diag_indices(M)
R[diag_ind] = np.log(R[diag_ind])
np.log(R[diag_ind])
eta1_d = R[triu_ind].reshape((M * (M + 1) / 2,))
eta2_d = theta2.reshape((M,))
eta1_R[d, :] = eta1_d
Expand Down
2 changes: 1 addition & 1 deletion tests/test_grads_aep.py
Original file line number Diff line number Diff line change
Expand Up @@ -5135,7 +5135,7 @@ def plot_gpssm_linear_aep_gaussian_stochastic():
# plot_gplvm_aep_probit_stochastic()
# plot_gplvm_aep_gaussian_stochastic()

test_gpr_aep_gaussian(True)
# test_gpr_aep_gaussian(True)
test_gpr_aep_gaussian(False)
# test_gpr_aep_probit()
# test_gpr_aep_gaussian_scipy()
Expand Down

0 comments on commit 68b6886

Please sign in to comment.