Skip to content

Commit

Permalink
add second derivatives for squar_sin_exp (#561)
Browse files Browse the repository at this point in the history
* add second derivatives for squar_sin_exp

* Uncomment exception and tests

* Small mistake fix
  • Loading branch information
NicolasJeanGonel authored May 17, 2024
1 parent 3f31c93 commit f8fe345
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 22 deletions.
1 change: 1 addition & 0 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,7 @@ def _predict_derivatives(self, x, kx):

dx = differences(x, Y=self.X_norma.copy())
d = self._componentwise_distance(dx)

dd = self._componentwise_distance(
dx, theta=self.optimal_theta, return_derivative=True
)
Expand Down
1 change: 0 additions & 1 deletion smt/surrogate_models/tests/test_krg_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def _check_prediction_derivatives(self, sm):
]
)
total_error = np.sum(pred_errors**2)

np.testing.assert_allclose(total_error, 0, atol=5e-3)

y_predicted = sm.predict_values(x_valid)
Expand Down
48 changes: 33 additions & 15 deletions smt/surrogate_models/tests/test_krg_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
matern52,
pow_exp,
squar_exp,
squar_sin_exp,
)
from smt.utils.misc import standardization
from smt.utils.sm_test_case import SMTestCase
Expand Down Expand Up @@ -50,7 +51,15 @@ def setUp(self):
"matern52",
"squar_sin_exp",
]
corr_def = [pow_exp, abs_exp, squar_exp, act_exp, matern32, matern52]
corr_def = [
pow_exp,
abs_exp,
squar_exp,
act_exp,
matern32,
matern52,
squar_sin_exp,
]
power_val = {
"pow_exp": 1.9,
"abs_exp": 1.0,
Expand Down Expand Up @@ -103,27 +112,31 @@ def test_corr_derivatives(self):
self.X.shape[1],
self.power_val[self.corr_str[ind]],
)
if corr == squar_sin_exp:
theta = self.random.rand(4)
else:
theta = self.theta

k = corr(self.theta, D)
k = corr(theta, D)
K = np.eye(self.X.shape[0])
K[self.ij[:, 0], self.ij[:, 1]] = k[:, 0]
K[self.ij[:, 1], self.ij[:, 0]] = k[:, 0]
grad_norm_all = []
diff_norm_all = []
ind_theta = []
for i, theta_i in enumerate(self.theta):
eps_theta = np.zeros(self.theta.shape)
for i, theta_i in enumerate(theta):
eps_theta = np.zeros(theta.shape)
eps_theta[i] = self.eps

k_dk = corr(self.theta + eps_theta, D)
k_dk = corr(theta + eps_theta, D)

K_dk = np.eye(self.X.shape[0])
K_dk[self.ij[:, 0], self.ij[:, 1]] = k_dk[:, 0]
K_dk[self.ij[:, 1], self.ij[:, 0]] = k_dk[:, 0]

grad_eps = (K_dk - K) / self.eps

dk = corr(self.theta, D, grad_ind=i)
dk = corr(theta, D, grad_ind=i)
dK = np.zeros((self.X.shape[0], self.X.shape[0]))
dK[self.ij[:, 0], self.ij[:, 1]] = dk[:, 0]
dK[self.ij[:, 1], self.ij[:, 0]] = dk[:, 0]
Expand All @@ -144,34 +157,38 @@ def test_corr_hessian(self):
self.power_val[self.corr_str[ind]],
)

if corr == squar_sin_exp:
theta = self.random.rand(4)
else:
theta = self.theta

grad_norm_all = []
diff_norm_all = []
for i, theta_i in enumerate(self.theta):
k = corr(self.theta, D, grad_ind=i)
for i, theta_i in enumerate(theta):
k = corr(theta, D, grad_ind=i)

K = np.eye(self.X.shape[0])
K[self.ij[:, 0], self.ij[:, 1]] = k[:, 0]
K[self.ij[:, 1], self.ij[:, 0]] = k[:, 0]
for j, omega_j in enumerate(self.theta):
eps_omega = np.zeros(self.theta.shape)
for j, omega_j in enumerate(theta):
eps_omega = np.zeros(theta.shape)
eps_omega[j] = self.eps

k_dk = corr(self.theta + eps_omega, D, grad_ind=i)
k_dk = corr(theta + eps_omega, D, grad_ind=i)

K_dk = np.eye(self.X.shape[0])
K_dk[self.ij[:, 0], self.ij[:, 1]] = k_dk[:, 0]
K_dk[self.ij[:, 1], self.ij[:, 0]] = k_dk[:, 0]

grad_eps = (K_dk - K) / self.eps

dk = corr(self.theta, D, grad_ind=i, hess_ind=j)
dk = corr(theta, D, grad_ind=i, hess_ind=j)
dK = np.zeros((self.X.shape[0], self.X.shape[0]))
dK[self.ij[:, 0], self.ij[:, 1]] = dk[:, 0]
dK[self.ij[:, 1], self.ij[:, 0]] = dk[:, 0]

grad_norm_all.append(np.linalg.norm(dK))
diff_norm_all.append(np.linalg.norm(grad_eps))

self.assert_error(
np.array(grad_norm_all), np.array(diff_norm_all), 1e-5, 1e-5
) # from utils/smt_test_case.py
Expand Down Expand Up @@ -225,14 +242,15 @@ def test_likelihood_derivatives(self):
) # from utils/smt_test_case.py

def test_likelihood_hessian(self):
self.setUp()
for corr_str in [
"squar_sin_exp",
"pow_exp",
"abs_exp",
"squar_exp",
"act_exp",
"matern32",
"matern52",
# "squar_sin_exp", # Yet to implement
]: # For every kernel
for poly_str in ["constant", "linear", "quadratic"]: # For every method
if corr_str == "squar_sin_exp":
Expand Down Expand Up @@ -283,7 +301,7 @@ def test_likelihood_hessian(self):

def test_variance_derivatives(self):
for corr_str in [
# "squar_sin_exp", ### Yet to implement
# "squar_sin_exp", ### Yet to implement
"abs_exp",
"squar_exp",
"matern32",
Expand Down
89 changes: 83 additions & 6 deletions smt/utils/kriging.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,6 @@ def squar_sin_exp(theta, d, grad_ind=None, hess_ind=None, derivative_params=None
"""

r = np.zeros((d.shape[0], 1))

# Construct/split the correlation matrix
i, nb_limit = 0, int(1e4)
while i * nb_limit <= d.shape[0]:
Expand All @@ -705,10 +704,16 @@ def squar_sin_exp(theta, d, grad_ind=None, hess_ind=None, derivative_params=None
)
)
i += 1
kernel = r.copy()

i = 0
if grad_ind is not None:
cut = int(len(theta) / 2)
if (
hess_ind is not None and grad_ind >= cut and hess_ind < cut
): # trick to use the symetry of the hessian when the hessian is asked
grad_ind, hess_ind = hess_ind, grad_ind

if grad_ind < cut:
grad_ind2 = cut + grad_ind
while i * nb_limit <= d.shape[0]:
Expand Down Expand Up @@ -741,11 +746,83 @@ def squar_sin_exp(theta, d, grad_ind=None, hess_ind=None, derivative_params=None
i = 0
if hess_ind is not None:
cut = int(len(theta) / 2)
# if hess_ind == grad_ind :
# else :
raise ValueError(
"Second derivatives for ExpSinSquared not available yet (to implement)."
)
if grad_ind < cut and hess_ind < cut:
hess_ind2 = cut + hess_ind
while i * nb_limit <= d.shape[0]:
r[i * nb_limit : (i + 1) * nb_limit, 0] = (
-(
np.sin(
theta_array[0][hess_ind2]
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind]
)
** 2
)
* r[i * nb_limit : (i + 1) * nb_limit, 0]
)
i += 1
elif grad_ind >= cut and hess_ind >= cut:
hess_ind2 = hess_ind - cut
if grad_ind == hess_ind:
while i * nb_limit <= d.shape[0]:
r[i * nb_limit : (i + 1) * nb_limit, 0] = (
-2
* theta_array[0][hess_ind2]
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2] ** 2
* np.cos(
2
* theta_array[0][grad_ind]
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
)
* kernel[i * nb_limit : (i + 1) * nb_limit, 0]
- theta_array[0][hess_ind2]
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* np.sin(
2
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* theta_array[0][hess_ind]
)
* r[i * nb_limit : (i + 1) * nb_limit, 0]
)
i += 1
else:
while i * nb_limit <= d.shape[0]:
r[i * nb_limit : (i + 1) * nb_limit, 0] = (
-theta_array[0][hess_ind2]
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* np.sin(
2
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* theta_array[0][hess_ind]
)
* r[i * nb_limit : (i + 1) * nb_limit, 0]
)
i += 1
elif grad_ind < cut and hess_ind >= cut:
hess_ind2 = hess_ind - cut
while i * nb_limit <= d.shape[0]:
r[i * nb_limit : (i + 1) * nb_limit, 0] = (
-theta_array[0][hess_ind2]
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* np.sin(
2
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* theta_array[0][hess_ind]
)
* r[i * nb_limit : (i + 1) * nb_limit, 0]
)
if hess_ind2 == grad_ind:
r[i * nb_limit : (i + 1) * nb_limit, 0] += (
-d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* np.sin(
2
* d[i * nb_limit : (i + 1) * nb_limit, hess_ind2]
* theta_array[0][hess_ind]
)
* kernel[i * nb_limit : (i + 1) * nb_limit, 0]
)

i += 1
i = 0

if derivative_params is not None:
raise ValueError(
Expand Down

0 comments on commit f8fe345

Please sign in to comment.