diff --git a/CHANGELOG.md b/CHANGELOG.md index c81332b2d34e..5af2dfbc9ced 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Current develop ### Added (new features/APIs/variables/...) +- [[PR 1060]](https://github.com/parthenon-hpc-lab/parthenon/pull/1060) Add the ability to request new MeshData/MeshBlockData objects by selecting variables by UID. - [[PR1039]](https://github.com/parthenon-hpc-lab/parthenon/pull/1039) Add ability to output custom coordinate positions for Visit/Paraview - [[PR1019](https://github.com/parthenon-hpc-lab/parthenon/pull/1019) Enable output for non-cell-centered variables diff --git a/src/interface/data_collection.cpp b/src/interface/data_collection.cpp index 275be3c7ff8f..81e6886e29e4 100644 --- a/src/interface/data_collection.cpp +++ b/src/interface/data_collection.cpp @@ -22,36 +22,14 @@ namespace parthenon { template -std::shared_ptr & -DataCollection::Add(const std::string &name, const std::shared_ptr &src, - const std::vector &field_names, const bool shallow) { - auto it = containers_.find(name); +std::shared_ptr &DataCollection::Add(const std::string &label) { + // error check for duplicate names + auto it = containers_.find(label); if (it != containers_.end()) { - if (!(it->second)->Contains(field_names)) { - PARTHENON_THROW(name + - "already exists in collection but does not contain field names"); - } return it->second; } - - auto c = std::make_shared(name); - c->Initialize(src.get(), field_names, shallow); - - Set(name, c); - - return containers_[name]; -} -template -std::shared_ptr &DataCollection::Add(const std::string &label, - const std::shared_ptr &src, - const std::vector &flags) { - return Add(label, src, flags, false); -} -template -std::shared_ptr &DataCollection::AddShallow(const std::string &label, - const std::shared_ptr &src, - const std::vector &flags) { - return Add(label, src, flags, true); + containers_[label] = std::make_shared(); + return containers_[label]; } std::shared_ptr> & diff --git a/src/interface/data_collection.hpp b/src/interface/data_collection.hpp index 6119bf0a8ae2..8dea579728c4 100644 --- a/src/interface/data_collection.hpp +++ b/src/interface/data_collection.hpp @@ -19,6 +19,8 @@ #include #include +#include "utils/error_checking.hpp" + namespace parthenon { class Mesh; /// The DataCollection class is an abstract container that contains at least a @@ -42,21 +44,35 @@ class DataCollection { void SetMeshPointer(Mesh *pmesh) { pmy_mesh_ = pmesh; } - std::shared_ptr &Add(const std::string &label, const std::shared_ptr &src, - const std::vector &flags, const bool shallow); - std::shared_ptr &Add(const std::string &label, const std::shared_ptr &src, - const std::vector &flags = {}); - std::shared_ptr &AddShallow(const std::string &label, const std::shared_ptr &src, - const std::vector &flags = {}); - std::shared_ptr &Add(const std::string &label) { - // error check for duplicate names - auto it = containers_.find(label); + template + std::shared_ptr &Add(const std::string &name, const std::shared_ptr &src, + const std::vector &fields, const bool shallow) { + auto it = containers_.find(name); if (it != containers_.end()) { + if (fields.size() && !(it->second)->ContainsExactly(fields)) { + PARTHENON_THROW(name + " already exists in collection but fields do not match."); + } return it->second; } - containers_[label] = std::make_shared(); - return containers_[label]; + + auto c = std::make_shared(name); + c->Initialize(src.get(), fields, shallow); + + Set(name, c); + + return containers_[name]; + } + template + std::shared_ptr &Add(const std::string &label, const std::shared_ptr &src, + const std::vector &fields = {}) { + return Add(label, src, fields, false); + } + template + std::shared_ptr &AddShallow(const std::string &label, const std::shared_ptr &src, + const std::vector &fields = {}) { + return Add(label, src, fields, true); } + std::shared_ptr &Add(const std::string &label); auto &Stages() { return containers_; } const auto &Stages() const { return containers_; } diff --git a/src/interface/mesh_data.cpp b/src/interface/mesh_data.cpp index 9f326c16e222..8e9172024b8b 100644 --- a/src/interface/mesh_data.cpp +++ b/src/interface/mesh_data.cpp @@ -16,31 +16,6 @@ namespace parthenon { -template -void MeshData::Initialize(const MeshData *src, - const std::vector &names, const bool shallow) { - if (src == nullptr) { - PARTHENON_THROW("src points at null"); - } - pmy_mesh_ = src->GetParentPointer(); - const int nblocks = src->NumBlocks(); - block_data_.resize(nblocks); - - grid = src->grid; - if (grid.type == GridType::two_level_composite) { - for (int i = 0; i < nblocks; i++) { - block_data_[i] = - pmy_mesh_->gmg_block_lists[src->grid.logical_level][i]->meshblock_data.Add( - stage_name_, src->GetBlockData(i), names, shallow); - } - } else { - for (int i = 0; i < nblocks; i++) { - block_data_[i] = pmy_mesh_->block_list[i]->meshblock_data.Add( - stage_name_, src->GetBlockData(i), names, shallow); - } - } -} - template void MeshData::Set(BlockList_t blocks, Mesh *pmesh, int ndim) { const int nblocks = blocks.size(); diff --git a/src/interface/mesh_data.hpp b/src/interface/mesh_data.hpp index e44e7def467e..e7f8c16e85b4 100644 --- a/src/interface/mesh_data.hpp +++ b/src/interface/mesh_data.hpp @@ -242,8 +242,32 @@ class MeshData { void Set(BlockList_t blocks, Mesh *pmesh, int ndim); void Set(BlockList_t blocks, Mesh *pmesh); - void Initialize(const MeshData *src, const std::vector &names, - const bool shallow); + + template + void Initialize(const MeshData *src, const std::vector &vars, + const bool shallow) { + if (src == nullptr) { + PARTHENON_THROW("src points at null"); + } + pmy_mesh_ = src->GetParentPointer(); + const int nblocks = src->NumBlocks(); + block_data_.resize(nblocks); + + // TODO(JMM/LFR): There is an edge case where if you call + // Initialize() on a set of meshblocks where some blocks contain + // the desired MeshBlockData object and some don't, this call will + // fail. (It will raise a runtime error due to a dictionary not + // being found.) This was present in the previous iteration of + // this code, as well as this iteration. Fixing this requires + // modifying DataCollection::GetOrAdd. In the future we should + // make that "just work (tm)." + grid = src->grid; + for (int i = 0; i < nblocks; ++i) { + auto pmbd = src->GetBlockData(i); + block_data_[i] = pmbd->GetBlockSharedPointer()->meshblock_data.Add( + stage_name_, pmbd, vars, shallow); + } + } const std::shared_ptr> &GetBlockData(int n) const { assert(n >= 0 && n < block_data_.size()); @@ -420,11 +444,17 @@ class MeshData { return true; } - bool Contains(const std::vector &names) const { - for (const auto &b : block_data_) { - if (!b->Contains(names)) return false; - } - return true; + // vars may be a subset of the MeshData object + template + bool Contains(const Vars_t &vars) const noexcept { + return std::all_of(block_data_.begin(), block_data_.end(), + [this, vars](const auto &b) { return b->Contains(vars); }); + } + // MeshData object must contain these vars and only these vars + template + bool ContainsExactly(const Vars_t &vars) const noexcept { + return std::all_of(block_data_.begin(), block_data_.end(), + [this, vars](const auto &b) { return b->ContainsExactly(vars); }); } SparsePackCache &GetSparsePackCache() { return sparse_pack_cache_; } diff --git a/src/interface/meshblock_data.cpp b/src/interface/meshblock_data.cpp index 5f685c0e119d..1b8185b4abce 100644 --- a/src/interface/meshblock_data.cpp +++ b/src/interface/meshblock_data.cpp @@ -79,45 +79,6 @@ void MeshBlockData::AddField(const std::string &base_name, const Metadata &me } } -// TODO(JMM): Move to unique IDs at some point -template -void MeshBlockData::Initialize(const MeshBlockData *src, - const std::vector &names, - const bool shallow_copy) { - assert(src != nullptr); - SetBlockPointer(src); - resolved_packages_ = src->resolved_packages_; - is_shallow_ = shallow_copy; - - auto add_var = [=](auto var) { - if (shallow_copy || var->IsSet(Metadata::OneCopy)) { - Add(var); - } else { - Add(var->AllocateCopy(pmy_block)); - } - }; - - // special case when the list of names is empty, copy everything - if (names.empty()) { - for (auto v : src->GetVariableVector()) { - add_var(v); - } - } else { - auto var_map = src->GetVariableMap(); - - for (const auto &name : names) { - bool found = false; - auto v = var_map.find(name); - if (v != var_map.end()) { - found = true; - add_var(v->second); - } - PARTHENON_REQUIRE_THROWS(found, "MeshBlockData::CopyFrom: Variable '" + name + - "' not found"); - } - } -} - /// Queries related to variable packs /// This is a helper function that queries the cache for the given pack. /// The strings are the keys and the lists are the values. diff --git a/src/interface/meshblock_data.hpp b/src/interface/meshblock_data.hpp index caf3be9c9d3c..f3e9583133b8 100644 --- a/src/interface/meshblock_data.hpp +++ b/src/interface/meshblock_data.hpp @@ -13,9 +13,11 @@ #ifndef INTERFACE_MESHBLOCK_DATA_HPP_ #define INTERFACE_MESHBLOCK_DATA_HPP_ +#include #include #include #include +#include #include #include #include @@ -109,8 +111,33 @@ class MeshBlockData { /// Create copy of MeshBlockData, possibly with a subset of named fields, /// and possibly shallow. Note when shallow=false, new storage is allocated /// for non-OneCopy vars, but the data from src is not actually deep copied - void Initialize(const MeshBlockData *src, const std::vector &names, - const bool shallow); + template + void Initialize(const MeshBlockData *src, const std::vector &vars, + const bool shallow_copy) { + PARTHENON_DEBUG_REQUIRE(src != nullptr, "Source data must be non-null."); + SetBlockPointer(src); + resolved_packages_ = src->resolved_packages_; + is_shallow_ = shallow_copy; + + auto add_var = [=](auto var) { + if (shallow_copy || var->IsSet(Metadata::OneCopy)) { + Add(var); + } else { + Add(var->AllocateCopy(pmy_block)); + } + }; + + // special case when the list of vars is empty, copy everything + if (vars.empty()) { + for (auto v : src->GetVariableVector()) { + add_var(v); + } + } else { + for (const auto &v : vars) { + add_var(src->GetVarPtr(v)); + } + } + } // // Queries related to Variable objects @@ -124,14 +151,13 @@ class MeshBlockData { const MapToVars &GetVariableMap() const noexcept { return varMap_; } std::shared_ptr> GetVarPtr(const std::string &label) const { - auto it = varMap_.find(label); - PARTHENON_REQUIRE_THROWS(it != varMap_.end(), + PARTHENON_REQUIRE_THROWS(varMap_.count(label), "Couldn't find variable '" + label + "'"); - return it->second; + return varMap_.at(label); } std::shared_ptr> GetVarPtr(const Uid_t &uid) const { PARTHENON_REQUIRE_THROWS(varUidMap_.count(uid), - "Variable ID " + std::to_string(uid) + "not found!"); + "Variable ID " + std::to_string(uid) + " not found!"); return varUidMap_.at(uid); } @@ -388,15 +414,18 @@ class MeshBlockData { return (my_keys == cmp_keys); } - bool Contains(const std::string &name) const noexcept { - if (varMap_.find(name) != varMap_.end()) return true; - return false; + bool Contains(const std::string &name) const noexcept { return varMap_.count(name); } + bool Contains(const Uid_t &uid) const noexcept { return varUidMap_.count(uid); } + template + bool Contains(const std::vector &vars) const noexcept { + return std::all_of(vars.begin(), vars.end(), + [this](const auto &v) { return this->Contains(v); }); } - bool Contains(const std::vector &names) const noexcept { - for (const auto &name : names) { - if (!Contains(name)) return false; - } - return true; + template + bool ContainsExactly(const std::vector &vars) const noexcept { + // JMM: Assumes vars contains no duplicates. But that would have + // been caught elsewhere because `MeshBlockData::Add` would have failed. + return Contains(vars) && (vars.size() == varVector_.size()); } void SetAllVariablesToInitialized() { @@ -419,6 +448,9 @@ class MeshBlockData { int sparse_id = InvalidSparseID); void Add(std::shared_ptr> var) noexcept { + if (varUidMap_.count(var->GetUniqueID())) { + PARTHENON_THROW("Tried to add variable " + var->label() + " twice!"); + } varVector_.push_back(var); varMap_[var->label()] = var; varUidMap_[var->GetUniqueID()] = var; diff --git a/src/mesh/mesh.hpp b/src/mesh/mesh.hpp index be0da22d8949..85db17f34331 100644 --- a/src/mesh/mesh.hpp +++ b/src/mesh/mesh.hpp @@ -123,8 +123,8 @@ class Mesh { std::map gmg_block_lists; std::map>> gmg_mesh_data; - int GetGMGMaxLevel() { return current_level; } - int GetGMGMinLevel() { return gmg_min_logical_level_; } + int GetGMGMaxLevel() const { return current_level; } + int GetGMGMinLevel() const { return gmg_min_logical_level_; } // functions void Initialize(bool init_problem, ParameterInput *pin, ApplicationInput *app_in); diff --git a/tst/unit/test_data_collection.cpp b/tst/unit/test_data_collection.cpp index 2e6c3d108133..3e84480c1e7e 100644 --- a/tst/unit/test_data_collection.cpp +++ b/tst/unit/test_data_collection.cpp @@ -91,7 +91,8 @@ TEST_CASE("Adding MeshBlockData objects to a DataCollection", "[DataCollection]" } } AND_WHEN("We want only a subset of variables in a new MeshBlockData") { - // reset vars + // reset vars so that we can check this is overwritten/or is a + // new stage par_for( loop_pattern_flatrange_tag, "init vars", DevExecSpace(), 0, 0, KOKKOS_LAMBDA(const int i) { v2(0) = 222; }); @@ -116,5 +117,35 @@ TEST_CASE("Adding MeshBlockData objects to a DataCollection", "[DataCollection]" REQUIRE(hxv2(0) == hv2(0)); } } + AND_WHEN("We want only a subset of variables in a new MeshBlockData by UID") { + // reset vars so that we can check this is overwritten/or is a + // new stage + par_for( + loop_pattern_flatrange_tag, "init vars", DevExecSpace(), 0, 0, + KOKKOS_LAMBDA(const int i) { v2(0) = 222; }); + std::vector uids; + uids.push_back(mbd->UniqueID("var2")); + uids.push_back(mbd->UniqueID("var3")); + auto x = d.Add("part", mbd, uids); + THEN("Requesting the missing variables should throw") { + REQUIRE_THROWS(x->Get("var1")); + } + AND_THEN("Requesting the specified variables should work as expected") { + auto &xv2 = x->Get("var2").data; + auto &xv3 = x->Get("var3").data; + par_for( + loop_pattern_flatrange_tag, "init vars", DevExecSpace(), 0, 0, + KOKKOS_LAMBDA(const int i) { + xv2(0) = 22; + xv3(0) = 33; + }); + auto hv2 = v2.GetHostMirrorAndCopy(); + auto hv3 = v3.GetHostMirrorAndCopy(); + auto hxv2 = xv2.GetHostMirrorAndCopy(); + auto hxv3 = xv3.GetHostMirrorAndCopy(); + REQUIRE(hxv3(0) != hv3(0)); + REQUIRE(hxv2(0) == hv2(0)); + } + } } }