Skip to content

Commit

Permalink
Allowing to normalize or not probs
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Feb 20, 2025
1 parent 17a4816 commit 2a6700a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
37 changes: 30 additions & 7 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3272,7 +3272,8 @@ class DataBase {
* @return std::vector< epiworld_double >
*/
std::vector< epiworld_double > transition_probability(
bool print = true
bool print = true,
bool normalize = true
) const;

bool operator==(const DataBase<TSeq> & other) const;
Expand Down Expand Up @@ -4493,7 +4494,8 @@ inline void DataBase<TSeq>::reproductive_number(

template<typename TSeq>
inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(
bool print
bool print,
bool normalize
) const {

auto states_labels = model->get_states();
Expand All @@ -4507,7 +4509,8 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(

for (size_t s_i = 0; s_i < n_state; ++s_i)
{
epiworld_double daily_total = hist_total_counts[(t - 1) * n_state + s_i];
epiworld_double daily_total =
hist_total_counts[(t - 1) * n_state + s_i];

if (daily_total == 0)
continue;
Expand Down Expand Up @@ -4542,10 +4545,13 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(

}

for (size_t s_i = 0; s_i < n_state; ++s_i)
if (normalize)
{
for (size_t s_j = 0; s_j < n_state; ++s_j)
res[s_i + s_j * n_state] /= days_to_include[s_i];
for (size_t s_i = 0; s_i < n_state; ++s_i)
{
for (size_t s_j = 0; s_j < n_state; ++s_j)
res[s_i + s_j * n_state] /= days_to_include[s_i];
}
}

if (print)
Expand All @@ -4557,6 +4563,21 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(
nchar = l.length();

std::string fmt = " - %-" + std::to_string(nchar) + "s";

std::string fmt_entry = " % 4.2f";
if (!normalize)
{
nchar = 0u;
for (auto & l: res)
{
std::string tmp = std::to_string(l);
if (tmp.length() > nchar)
nchar = tmp.length();
}

fmt_entry = " % " + std::to_string(nchar) + ".0f";
}


printf_epiworld("\nTransition Probabilities:\n");
for (size_t s_i = 0u; s_i < n_state; ++s_i)
Expand All @@ -4568,7 +4589,9 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(
{
printf_epiworld(" -");
} else {
printf_epiworld(" % 4.2f", res[s_i + s_j * n_state]);
printf_epiworld(
fmt_entry.c_str(), res[s_i + s_j * n_state]
);
}
}
printf_epiworld("\n");
Expand Down
3 changes: 2 additions & 1 deletion include/epiworld/database-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ class DataBase {
* @return std::vector< epiworld_double >
*/
std::vector< epiworld_double > transition_probability(
bool print = true
bool print = true,
bool normalize = true
) const;

bool operator==(const DataBase<TSeq> & other) const;
Expand Down
34 changes: 28 additions & 6 deletions include/epiworld/database-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,8 @@ inline void DataBase<TSeq>::reproductive_number(

template<typename TSeq>
inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(
bool print
bool print,
bool normalize
) const {

auto states_labels = model->get_states();
Expand All @@ -1184,7 +1185,8 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(

for (size_t s_i = 0; s_i < n_state; ++s_i)
{
epiworld_double daily_total = hist_total_counts[(t - 1) * n_state + s_i];
epiworld_double daily_total =
hist_total_counts[(t - 1) * n_state + s_i];

if (daily_total == 0)
continue;
Expand Down Expand Up @@ -1219,10 +1221,13 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(

}

for (size_t s_i = 0; s_i < n_state; ++s_i)
if (normalize)
{
for (size_t s_j = 0; s_j < n_state; ++s_j)
res[s_i + s_j * n_state] /= days_to_include[s_i];
for (size_t s_i = 0; s_i < n_state; ++s_i)
{
for (size_t s_j = 0; s_j < n_state; ++s_j)
res[s_i + s_j * n_state] /= days_to_include[s_i];
}
}

if (print)
Expand All @@ -1234,6 +1239,21 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(
nchar = l.length();

std::string fmt = " - %-" + std::to_string(nchar) + "s";

std::string fmt_entry = " % 4.2f";
if (!normalize)
{
nchar = 0u;
for (auto & l: res)

Check warning on line 1247 in include/epiworld/database-meat.hpp

View check run for this annotation

Codecov / codecov/patch

include/epiworld/database-meat.hpp#L1246-L1247

Added lines #L1246 - L1247 were not covered by tests
{
std::string tmp = std::to_string(l);
if (tmp.length() > nchar)

Check warning on line 1250 in include/epiworld/database-meat.hpp

View check run for this annotation

Codecov / codecov/patch

include/epiworld/database-meat.hpp#L1249-L1250

Added lines #L1249 - L1250 were not covered by tests
nchar = tmp.length();
}

fmt_entry = " % " + std::to_string(nchar) + ".0f";

Check warning on line 1254 in include/epiworld/database-meat.hpp

View check run for this annotation

Codecov / codecov/patch

include/epiworld/database-meat.hpp#L1254

Added line #L1254 was not covered by tests
}


printf_epiworld("\nTransition Probabilities:\n");
for (size_t s_i = 0u; s_i < n_state; ++s_i)
Expand All @@ -1245,7 +1265,9 @@ inline std::vector< epiworld_double > DataBase<TSeq>::transition_probability(
{
printf_epiworld(" -");
} else {
printf_epiworld(" % 4.2f", res[s_i + s_j * n_state]);
printf_epiworld(
fmt_entry.c_str(), res[s_i + s_j * n_state]
);
}
}
printf_epiworld("\n");
Expand Down

0 comments on commit 2a6700a

Please sign in to comment.