Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1187 Make set_initial_flows() of IDE model usable with and without age resolution #1188

17 changes: 16 additions & 1 deletion cpp/examples/ide_initialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "ide_secir/simulation.h"
#include "ide_secir/parameters_io.h"
#include "memilio/config.h"
#include "memilio/io/epi_data.h"
#include "memilio/utils/time_series.h"
#include "memilio/utils/date.h"
#include "memilio/math/eigen.h"
Expand Down Expand Up @@ -88,7 +89,21 @@ int main(int argc, char** argv)
}
else {
// Use the real data for initialization.
auto status = mio::isecir::set_initial_flows(model, dt, filename, mio::Date(2020, 12, 24));
// Here we assume that the file contains data without age resolution, hence we use read_confirmed_cases_noage()
// for reading the data and mio::ConfirmedCasesNoAgeEntry as EntryType in set_initial_flows().

auto status_read_data = mio::read_confirmed_cases_noage(filename);
if (!status_read_data) {
std::cout << "Error: " << status_read_data.error().formatted_message();
return -1;
}

std::vector<mio::ConfirmedCasesNoAgeEntry> rki_data = status_read_data.value();
mio::CustomIndexArray<ScalarType, mio::AgeGroup> scale_confirmed_cases =
mio::CustomIndexArray<ScalarType, mio::AgeGroup>(mio::AgeGroup(num_agegroups), 1.);

auto status = mio::isecir::set_initial_flows<mio::ConfirmedCasesNoAgeEntry>(
model, dt, rki_data, mio::Date(2020, 12, 24), scale_confirmed_cases);
if (!status) {
std::cout << "Error: " << status.error().formatted_message();
return -1;
Expand Down
1 change: 0 additions & 1 deletion cpp/models/ide_secir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ add_library(ide_secir
simulation.cpp
parameters.h
parameters_io.h
parameters_io.cpp
)
target_link_libraries(ide_secir PUBLIC memilio)
target_include_directories(ide_secir PUBLIC
Expand Down
23 changes: 21 additions & 2 deletions cpp/models/ide_secir/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ Model::Model(TimeSeries<ScalarType>&& transitions_init, CustomIndexArray<ScalarT
, m_num_agegroups{num_agegroups}

{
// Assert that input arguments for the total population have the correct size regarding
// age groups.
assert((size_t)m_N.size() == m_num_agegroups);

if (transitions.get_num_time_points() > 0) {
// Add first time point in populations according to last time point in transitions which is where we start
// Add first time point in m_populations according to last time point in m_transitions which is where we start
// the simulation.
populations.add_time_point<Eigen::VectorX<ScalarType>>(
transitions.get_last_time(),
Expand All @@ -71,7 +75,7 @@ bool Model::check_constraints(ScalarType dt) const
{

if (!((size_t)transitions.get_num_elements() == (size_t)InfectionTransition::Count * m_num_agegroups)) {
log_error("A variable given for model construction is not valid. Number of elements in vector of"
log_error("A variable given for model construction is not valid. Number of elements in vector of "
"transitions does not match the required number.");
return true;
}
Expand Down Expand Up @@ -120,6 +124,21 @@ bool Model::check_constraints(ScalarType dt) const
return true;
}

if ((size_t)total_confirmed_cases.size() > 0 && (size_t)total_confirmed_cases.size() != m_num_agegroups) {
log_error("Initialization failed. Number of elements in total_confirmed_cases does not match the number "
"of age groups.");
return true;
}

if ((size_t)total_confirmed_cases.size() > 0) {
for (AgeGroup group = AgeGroup(0); group < AgeGroup(m_num_agegroups); ++group) {
if (total_confirmed_cases[group] < 0) {
log_error("Initialization failed. One or more value of total_confirmed_cases is less than zero.");
return true;
}
}
}

return parameters.check_constraints();
}

Expand Down
24 changes: 22 additions & 2 deletions cpp/models/ide_secir/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "memilio/config.h"
#include "memilio/epidemiology/age_group.h"
#include "memilio/utils/custom_index_array.h"
#include "memilio/utils/date.h"
#include "memilio/utils/time_series.h"

#include "vector"
Expand All @@ -34,6 +35,13 @@ namespace mio
{
namespace isecir
{
// Forward declaration of friend classes/functions of Model.
class Model;
class Simulation;
template <typename EntryType>
IOResult<void> set_initial_flows(Model& model, const ScalarType dt, const std::vector<EntryType> rki_data,
const Date date, const CustomIndexArray<ScalarType, AgeGroup> scale_confirmed_cases);

class Model
{
using ParameterSet = Parameters;
Expand Down Expand Up @@ -130,6 +138,16 @@ class Model
return m_initialization_method;
}

/**
* @brief Getter for number of age groups.
annawendler marked this conversation as resolved.
Show resolved Hide resolved
*
* @return Returns number of age groups.
*/
size_t get_num_agegroups() const
{
return m_num_agegroups;
}

/**
* @brief Setter for the tolerance used to calculate the maximum support of the TransitionDistributions.
*
Expand Down Expand Up @@ -358,8 +376,10 @@ class Model
friend class Simulation;
// In set_initial_flows(), we compute initial flows based on RKI data using the (private) compute_flow() function
// which is why it is defined as a friend function.
friend IOResult<void> set_initial_flows(Model& model, ScalarType dt, std::string const& path, Date date,
ScalarType scale_confirmed_cases);
template <typename EntryType>
friend IOResult<void> set_initial_flows(Model& model, const ScalarType dt, const std::vector<EntryType> rki_data,
const Date date,
const CustomIndexArray<ScalarType, AgeGroup> scale_confirmed_cases);
};

} // namespace isecir
Expand Down
Loading