-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathadam_poisson_2d.py
104 lines (82 loc) · 2.55 KB
/
adam_poisson_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Adam Optimization.
Two dimensional Poisson equation example. Solution given by
u(x,y) = sin(pi*x) * sin(pi*y).
"""
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit
import optax
from ngrad.models import init_params, mlp
from ngrad.domains import Square, SquareBoundary
from ngrad.integrators import DeterministicIntegrator
from ngrad.utility import laplace
# Switch JAX to 64-bit mode before any arrays are created.
jax.config.update("jax_enable_x64", True)

# Seed used to build the PRNG key for parameter initialisation.
seed = 0

# Domains: the interior and its boundary, both constructed with argument 1.0
# (presumably the side length of the square -- confirm in ngrad.domains).
interior = Square(1.0)
boundary = SquareBoundary(1.0)

# Integrators: the two used by the loss are built with 30, the one used for
# error evaluation with 200 (presumably a resolution parameter -- confirm in
# ngrad.integrators).
interior_integrator = DeterministicIntegrator(interior, 30)
boundary_integrator = DeterministicIntegrator(boundary, 30)
eval_integrator = DeterministicIntegrator(interior, 200)


# Model: tanh network with layer sizes 2 -> 32 -> 1.
def activation(x):
    """Activation handed to the MLP: elementwise hyperbolic tangent."""
    return jnp.tanh(x)


layer_sizes = [2, 32, 1]
params = init_params(layer_sizes, random.PRNGKey(seed))
model = mlp(activation)
# Batched model: parameters are shared, the second argument is mapped over
# its leading axis.
v_model = vmap(model, in_axes=(None, 0))
# solution
@jit
def u_star(x):
    """Evaluate the exact solution at a single point.

    Returns the product over all entries of ``x`` of ``sin(pi * x_i)``; for a
    two-component point ``(x, y)`` this is ``sin(pi*x) * sin(pi*y)``.
    """
    return jnp.prod(jnp.sin(jnp.pi * x))
# rhs
@jit
def f(x):
    """Evaluate the right-hand side at a single point.

    Returns ``2 * pi**2 * u_star(x)``. For a two-component ``x`` this equals
    minus the Laplacian of ``sin(pi*x) * sin(pi*y)``, i.e. the forcing term
    of ``-laplace(u) = f`` for that exact solution.
    """
    return 2. * jnp.pi**2 * u_star(x)
# compute residual
def laplace_model(params):
    """Return ``laplace`` applied to the network with ``params`` held fixed."""

    def network(x):
        return model(params, x)

    return laplace(network)


def residual(params, x):
    """Squared pointwise residual ``(laplace_model(params)(x) + f(x)) ** 2``."""
    pointwise = laplace_model(params)(x) + f(x)
    return pointwise**2.


# Batched, compiled residual: parameters are shared, ``x`` is mapped over its
# leading axis.
v_residual = jit(vmap(residual, in_axes=(None, 0)))
# loss
@jit
def interior_loss(params):
    """Interior part of the loss.

    Hands ``interior_integrator`` the batched squared residual for the given
    ``params`` and returns whatever the integrator returns.
    """
    return interior_integrator(lambda x: v_residual(params, x))
@jit
def boundary_loss(params):
    """Boundary part of the loss.

    Hands ``boundary_integrator`` the squared batched model output for the
    given ``params``; the model value itself is penalised, i.e. a zero
    boundary value is targeted.
    """
    return boundary_integrator(lambda x: v_model(params, x)**2)
@jit
def loss(params):
    """Total training loss: ``interior_loss(params) + boundary_loss(params)``."""
    return interior_loss(params) + boundary_loss(params)
# errors
def error(x):
    """Pointwise error ``model(params, x) - u_star(x)``.

    ``params`` is the module-level name and is looked up on every call, so
    the error always refers to the current parameters.
    """
    return model(params, x) - u_star(x)


# Batched error over the leading axis of the input.
v_error = vmap(error, in_axes=0)


def _error_grad_norm(x):
    """Euclidean norm of the gradient of ``error`` at a single point."""
    # Evaluate the gradient once and reuse it for the dot product.
    g = grad(error)(x)
    return jnp.dot(g, g)**0.5


# Batched gradient norm of the error.
v_error_abs_grad = vmap(_error_grad_norm)
def l2_norm(f, integrator):
    """Return the square root of ``integrator`` applied to ``f`` squared.

    Args:
        f: Callable evaluated by the integrator; its result is squared.
        integrator: Callable that takes a function and returns its integral.

    Returns:
        ``integrator(lambda x: f(x) ** 2) ** 0.5``.
    """
    return integrator(lambda x: (f(x))**2)**0.5
# optimizer settings
# Learning-rate schedule: starts at 1e-3, decay begins after step 15000 with
# a factor of 0.1 per 10000 transition steps, and the rate is bounded below
# by 1e-7.
exponential_decay = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=10_000,
    transition_begin=15_000,
    decay_rate=0.1,
    end_value=1e-7,
)

# Adam driven by the schedule above; its state is initialised from the
# current network parameters.
optimizer = optax.adam(learning_rate=exponential_decay)
opt_state = optimizer.init(params)
# Adam optimisation loop (plain Adam updates; no line search is performed).
# The gradient function is built and compiled once instead of being
# re-created with grad(loss) on every iteration.
loss_grad = jit(grad(loss))

for iteration in range(200000):
    grads = loss_grad(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if iteration % 1000 == 0:
        # errors
        l2_error = l2_norm(v_error, eval_integrator)
        # Reported "H1" value is the sum of the L2 norm of the error and the
        # L2 norm of its gradient magnitude.
        h1_error = l2_error + l2_norm(v_error_abs_grad, eval_integrator)
        print(
            f'Adam Iteration: {iteration} with loss: {loss(params)} with error '
            f'L2: {l2_error} and error H1: {h1_error}.'
        )