Skip to content

Commit

Permalink
Correcting minor bug in dist functions and adding optional overwrite …
Browse files Browse the repository at this point in the history
…to read and add params
  • Loading branch information
gvegayon committed Feb 21, 2025
1 parent 82472d7 commit 331845c
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 26 deletions.
41 changes: 28 additions & 13 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
/* Versioning */
#define EPIWORLD_VERSION_MAJOR 0
#define EPIWORLD_VERSION_MINOR 7
#define EPIWORLD_VERSION_PATCH 0
#define EPIWORLD_VERSION_PATCH 1

static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR;
static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR;
Expand Down Expand Up @@ -6964,8 +6964,10 @@ class Model {
*
*/
///@{
epiworld_double add_param(epiworld_double initial_val, std::string pname);
void read_params(std::string fn);
epiworld_double add_param(
epiworld_double initial_val, std::string pname, bool overwrite = false
);
void read_params(std::string fn, bool overwrite = false);
epiworld_double get_param(epiworld_fast_uint k);
epiworld_double get_param(std::string pname);
// void set_param(size_t k, epiworld_double val);
Expand Down Expand Up @@ -9483,18 +9485,23 @@ inline void Model<TSeq>::print_state_codes() const
template<typename TSeq>
inline epiworld_double Model<TSeq>::add_param(
epiworld_double initial_value,
std::string pname
std::string pname,
bool overwrite
) {

if (parameters.find(pname) == parameters.end())
parameters[pname] = initial_value;
else if (!overwrite)
throw std::logic_error("The parameter " + pname + " already exists.");
else
parameters[pname] = initial_value;

return initial_value;

}

template<typename TSeq>
inline void Model<TSeq>::read_params(std::string fn)
inline void Model<TSeq>::read_params(std::string fn, bool overwrite)
{

std::ifstream paramsfile(fn);
Expand Down Expand Up @@ -9532,7 +9539,8 @@ inline void Model<TSeq>::read_params(std::string fn)
std::regex_replace(
match[1u].str(),
std::regex("^\\s+|\\s+$"),
"")
""),
overwrite
);

}
Expand Down Expand Up @@ -10567,15 +10575,18 @@ inline VirusToAgentFun<TSeq> distribute_virus_randomly(
idx.push_back(agent.get_id());

// Picking how many
size_t n = model->size();
int n = model->size();
int n_available = static_cast<int>(idx.size());
int n_to_sample;
if (prevalence_as_proportion)
{
n_to_sample = static_cast<int>(std::floor(prevalence * n));
n_to_sample = static_cast<int>(std::floor(
prevalence * static_cast< epiworld_double >(n)
));

if (n_to_sample == static_cast<int>(n))
n_to_sample--;
// Correcting for possible overflow
if (n_to_sample == (n + 1))
--n_to_sample;
}
else
{
Expand Down Expand Up @@ -11816,8 +11827,10 @@ inline ToolToAgentFun<TSeq> distribute_tool_randomly(
{
n_to_distribute = static_cast<int>(std::floor(prevalence * n));

if (n_to_distribute == n)
n_to_distribute--;
// Correcting for possible rounding errors
if (n_to_distribute == (n + 1))
--n_to_distribute;

}
else
{
Expand Down Expand Up @@ -12617,7 +12630,9 @@ inline EntityToAgentFun<TSeq> distribute_entity_randomly(
if (as_proportion)
{
n_to_sample = static_cast<int>(std::floor(prevalence * n));
if (n_to_sample > static_cast<int>(n))

// Correcting for possible overflow
if (n_to_sample == (static_cast<int>(n) + 1))
--n_to_sample;

} else
Expand Down
4 changes: 3 additions & 1 deletion include/epiworld/entity-distribute-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ inline EntityToAgentFun<TSeq> distribute_entity_randomly(
if (as_proportion)
{
n_to_sample = static_cast<int>(std::floor(prevalence * n));
if (n_to_sample > static_cast<int>(n))

// Correcting for possible overflow
if (n_to_sample == (static_cast<int>(n) + 1))
--n_to_sample;

} else
Expand Down
2 changes: 1 addition & 1 deletion include/epiworld/epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
/* Versioning */
#define EPIWORLD_VERSION_MAJOR 0
#define EPIWORLD_VERSION_MINOR 7
#define EPIWORLD_VERSION_PATCH 0
#define EPIWORLD_VERSION_PATCH 1

static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR;
static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR;
Expand Down
6 changes: 4 additions & 2 deletions include/epiworld/model-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,10 @@ class Model {
*
*/
///@{
epiworld_double add_param(epiworld_double initial_val, std::string pname);
void read_params(std::string fn);
epiworld_double add_param(
epiworld_double initial_val, std::string pname, bool overwrite = false
);
void read_params(std::string fn, bool overwrite = false);
epiworld_double get_param(epiworld_fast_uint k);
epiworld_double get_param(std::string pname);
// void set_param(size_t k, epiworld_double val);
Expand Down
12 changes: 9 additions & 3 deletions include/epiworld/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2035,18 +2035,23 @@ inline void Model<TSeq>::print_state_codes() const
template<typename TSeq>
inline epiworld_double Model<TSeq>::add_param(
epiworld_double initial_value,
std::string pname
std::string pname,
bool overwrite
) {

if (parameters.find(pname) == parameters.end())
parameters[pname] = initial_value;
else if (!overwrite)
throw std::logic_error("The parameter " + pname + " already exists.");
else
parameters[pname] = initial_value;

return initial_value;

}

template<typename TSeq>
inline void Model<TSeq>::read_params(std::string fn)
inline void Model<TSeq>::read_params(std::string fn, bool overwrite)
{

std::ifstream paramsfile(fn);
Expand Down Expand Up @@ -2084,7 +2089,8 @@ inline void Model<TSeq>::read_params(std::string fn)
std::regex_replace(
match[1u].str(),
std::regex("^\\s+|\\s+$"),
"")
""),
overwrite
);

}
Expand Down
6 changes: 4 additions & 2 deletions include/epiworld/tool-distribute-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ inline ToolToAgentFun<TSeq> distribute_tool_randomly(
{
n_to_distribute = static_cast<int>(std::floor(prevalence * n));

if (n_to_distribute == n)
n_to_distribute--;
// Correcting for possible rounding errors
if (n_to_distribute == (n + 1))
--n_to_distribute;

}
else
{
Expand Down
11 changes: 7 additions & 4 deletions include/epiworld/virus-distribute-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,18 @@ inline VirusToAgentFun<TSeq> distribute_virus_randomly(
idx.push_back(agent.get_id());

// Picking how many
size_t n = model->size();
int n = model->size();
int n_available = static_cast<int>(idx.size());
int n_to_sample;
if (prevalence_as_proportion)
{
n_to_sample = static_cast<int>(std::floor(prevalence * n));
n_to_sample = static_cast<int>(std::floor(
prevalence * static_cast< epiworld_double >(n)
));

if (n_to_sample == static_cast<int>(n))
n_to_sample--;
// Correcting for possible overflow
if (n_to_sample == (n + 1))
--n_to_sample;
}
else
{
Expand Down

0 comments on commit 331845c

Please sign in to comment.