-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtests.py
94 lines (67 loc) · 2.74 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import unittest
import numpy as np
import pybamm as pb
from map_serial import solve_w_pool_serial
from pool import solve_w_pool
from serial import solve_serial
from sharedarray import solve_w_SharedArray
from solve import solve_w_pool_solve
def current_function(t):
    """Return the symbolic ``"Current"`` input parameter.

    The argument ``t`` is accepted but not used by the body: the same
    ``pb.InputParameter("Current")`` expression is produced for every call.
    """
    current = pb.InputParameter("Current")
    return current
def get_initial_solution(model, t_eval, inputs):
    """Solve ``model`` once with a newly created ``pb.CasadiSolver``.

    Parameters
    ----------
    model
        Model handed to the solver as its first positional argument.
    t_eval
        Evaluation times, handed to the solver as its second positional
        argument.
    inputs
        Forwarded to the solver through its ``inputs`` keyword.

    Returns
    -------
    Whatever ``pb.CasadiSolver().solve(...)`` returns for these arguments.
    """
    return pb.CasadiSolver().solve(model, t_eval, inputs=inputs)
class TestEnsembleSimulation(unittest.TestCase):
    """Check the ensemble solver variants against a stored reference array.

    ``setUp`` builds and discretises an SPMe model whose current is supplied
    through the ``"Current"`` input, computes an initial solution, and loads
    the reference array from ``ref/ref_solution.bin``. Each test then runs
    one solver entry point with the same ``Nsteps``/``dt``/``Nspm`` settings.
    """

    def setUp(self):
        """Prepare the model, the initial solution and the reference data."""
        pb.set_logging_level("WARNING")

        # load model
        self.model = pb.lithium_ion.SPMe()

        # create geometry
        geometry = self.model.default_geometry

        # load parameter values and process model and geometry
        param = self.model.default_parameter_values
        param.update({"Current function [A]": current_function})
        # "Current" is not one of the default parameters, hence the
        # explicit check_already_exists=False.
        param.update({"Current": "[input]"}, check_already_exists=False)
        param.process_model(self.model)
        param.process_geometry(geometry)

        # set mesh
        mesh = pb.Mesh(
            geometry, self.model.default_submesh_types, self.model.default_var_pts
        )

        # discretise self.model
        disc = pb.Discretisation(mesh, self.model.default_spatial_methods)
        disc.process_model(self.model)

        self.sol_init = get_initial_solution(
            self.model, np.linspace(0, 1, 2), {"Current": 0.67}
        )

        # Settings passed unchanged to every solver under test.
        self.Nsteps = 10
        self.dt = 1
        self.Nspm = 8

        # The reference is a flat binary dump read with np.fromfile's default
        # dtype; the path is relative to the current working directory.
        expected_y_flat = np.fromfile("ref/ref_solution.bin")
        n_rows = self.sol_init.y.shape[0]
        # Column count is derived from the file size rather than taken from
        # self.Nspm, so it gets its own name to avoid confusing the two.
        n_ref_columns = len(expected_y_flat) // n_rows
        self.expected_y = expected_y_flat.reshape((n_rows, n_ref_columns))

    def _assert_matches_reference(self, y):
        """Assert that ``y`` equals the stored reference to 5 decimal places."""
        np.testing.assert_almost_equal(y, self.expected_y, decimal=5)

    def test_SharedArray(self):
        """``solve_w_SharedArray`` reproduces the reference array."""
        y, _t = solve_w_SharedArray(
            self.model, self.sol_init, self.Nsteps, self.dt, self.Nspm
        )
        self._assert_matches_reference(y)

    def test_Pool(self):
        """``solve_w_pool`` reproduces the reference array."""
        y, _t = solve_w_pool(
            self.model, self.sol_init, self.Nsteps, self.dt, self.Nspm
        )
        self._assert_matches_reference(y)

    def test_Pool_solve(self):
        """Smoke test: ``solve_w_pool_solve`` runs and returns a 2-tuple.

        Unlike the other tests, this call is not given ``self.sol_init`` and
        its result is not compared with the reference.
        TODO(review): confirm whether the output is expected to match
        ``self.expected_y`` and, if so, add the comparison.
        """
        # Unpacking still verifies that exactly two values come back.
        _y, _t = solve_w_pool_solve(self.model, self.Nsteps, self.dt, self.Nspm)

    def test_Serial(self):
        """``solve_serial`` reproduces the reference array."""
        y, _t = solve_serial(
            self.model, self.sol_init, self.Nsteps, self.dt, self.Nspm
        )
        self._assert_matches_reference(y)

    def test_Pool_serial(self):
        """``solve_w_pool`` called with ``serial=True`` reproduces the reference."""
        # NOTE(review): the file imports solve_w_pool_serial but this test
        # exercises solve_w_pool(serial=True) instead - confirm which entry
        # point was intended.
        y, _t = solve_w_pool(
            self.model, self.sol_init, self.Nsteps, self.dt, self.Nspm, serial=True
        )
        self._assert_matches_reference(y)