From 689eea33cddb42f28a63238274ef12ff3135c4e2 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 22 May 2024 09:21:19 +0200 Subject: [PATCH] Feature: High-level network specification (#2050) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement a high-level network specification as proposed in #418. It does not include support for gap junctions to allow the use of domain decomposition for some distributed network generation. The general idea is a DSL based on set algebra, which operates on the set of all possible connections, by selecting based on different criteria, such as the distance between cells or lists of labels. By operating on all possible connections, a separate definition of cell populations becomes unnecessary. An example for selecting all inter-cell connections with a certain source and destination label is: `(intersect (inter-cell) (source-label \"detector\") (destination-label \"syn\"))` For parameters such as weight and delay, a value can be defined in the DSL in a similar way with the usual mathematical operations available. An example would be: `(max 0.1 (exp (mul -0.5 (distance))))` The position of each connection site is calculated by resolving the local position on the cell and applying an isometry, which is provided by a new optional function of the recipe. In contrast to the usage of policies to select a member within a locset, each site is treated individually and can be distinguished by its position. Internally, some steps have been implemented in an attempt to reduce the overhead of generating connections: - Pre-select source and destination sites based on the selection to reduce the sampling space when possible - If selection is limited to a maximum distance, use an octree for efficient spatial sampling - When using MPI, only instantiate local cells and exchange source sites in a ring communication pattern to overlap communication and sampling. In addition, this reduces memory usage, since only the current and next source sites have to be stored in memory during the exchange process. Custom selection and value functions can still be provided by storing the wrapped function in a dictionary with an associated label, which can then be used in the DSL. Some challenges remain. In particular, how to handle combined explicit connections returned by `connections_on` and the new way to describe a network. Also, the use of non-blocking MPI is not easily integrated into the current context types, and the dry-run context is not supported so far. # Example A (trimmed) example in Python, where a ring connection combined with random connections based on the distance: ```py class recipe(arbor.recipe): def cell_isometry(self, gid): # place cells with equal distance on a circle radius = 500.0 # μm angle = 2.0 * math.pi * gid / self.ncells return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0) def network_description(self): seed = 42 # create a chain ring = f"(chain (gid-range 0 {self.ncells}))" # connect front and back of chain to form ring ring = f"(join {ring} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" # Create random connections with probability inversely proportional to the distance within a # radius max_dist = 400.0 # μm probability = f"(div (sub {max_dist} (distance)) {max_dist})" rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" # combine ring with random selection s = f"(join {ring} {rand})" # restrict to inter-cell connections and certain source / destination labels s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))" # normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS # and truncated to [0.005, 0.035] w = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" # fixed delay d = "(scalar 5.0)" # ms delay return arbor.network_description(s, w, d, {}) ``` Co-authored-by: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> --- arbor/CMakeLists.txt | 2 + arbor/communication/communicator.cpp | 53 +- arbor/communication/communicator.hpp | 13 +- arbor/communication/distributed_for_each.hpp | 185 +++ arbor/communication/dry_run_context.cpp | 17 + arbor/communication/mpi.hpp | 57 + arbor/communication/mpi_context.cpp | 68 + arbor/connection.hpp | 6 +- arbor/distributed_context.hpp | 86 + arbor/domain_decomposition.cpp | 23 +- arbor/include/arbor/common_types.hpp | 13 + arbor/include/arbor/domain_decomposition.hpp | 6 +- arbor/include/arbor/math.hpp | 13 + arbor/include/arbor/network.hpp | 331 ++++ arbor/include/arbor/network_generation.hpp | 20 + arbor/include/arbor/recipe.hpp | 12 +- arbor/include/arbor/simulation.hpp | 2 +- arbor/network.cpp | 1453 +++++++++++++++++ arbor/network_impl.cpp | 308 ++++ arbor/network_impl.hpp | 67 + arbor/simulation.cpp | 11 +- arbor/threading/threading.cpp | 8 + arbor/threading/threading.hpp | 6 +- arbor/util/spatial_tree.hpp | 190 +++ arbor/util/visit_variant.hpp | 41 + arborio/CMakeLists.txt | 1 + arborio/include/arborio/networkio.hpp | 44 + arborio/networkio.cpp | 369 +++++ doc/concepts/interconnectivity.rst | 278 +++- doc/cpp/interconnectivity.rst | 293 +++- doc/python/interconnectivity.rst | 93 +- doc/python/recipe.rst | 11 + example/CMakeLists.txt | 1 + example/network_description/CMakeLists.txt | 4 + example/network_description/branch_cell.hpp | 132 ++ .../network_description.cpp | 343 ++++ example/network_description/readme.md | 3 + python/CMakeLists.txt | 1 + python/example/network_description.py | 182 +++ python/network.cpp | 163 ++ python/pyarb.cpp | 2 + python/recipe.cpp | 4 + python/recipe.hpp | 27 +- scripts/run_cpp_examples.sh | 2 + scripts/run_python_examples.sh | 1 + test/unit-distributed/CMakeLists.txt | 6 +- test/unit-distributed/test_communicator.cpp | 8 +- .../test_distributed_for_each.cpp | 92 ++ .../test_network_generation.cpp | 167 ++ test/unit/CMakeLists.txt | 2 + test/unit/test_domain_decomposition.cpp | 9 + test/unit/test_network.cpp | 817 +++++++++ test/unit/test_s_expr.cpp | 147 +- test/unit/test_spatial_tree.cpp | 155 ++ 54 files changed, 6259 insertions(+), 89 deletions(-) create mode 100644 arbor/communication/distributed_for_each.hpp create mode 100644 arbor/include/arbor/network.hpp create mode 100644 arbor/include/arbor/network_generation.hpp create mode 100644 arbor/network.cpp create mode 100644 arbor/network_impl.cpp create mode 100644 arbor/network_impl.hpp create mode 100644 arbor/util/spatial_tree.hpp create mode 100644 arbor/util/visit_variant.hpp create mode 100644 arborio/include/arborio/networkio.hpp create mode 100644 arborio/networkio.cpp create mode 100644 example/network_description/CMakeLists.txt create mode 100644 example/network_description/branch_cell.hpp create mode 100644 example/network_description/network_description.cpp create mode 100644 example/network_description/readme.md create mode 100755 python/example/network_description.py create mode 100644 python/network.cpp create mode 100644 test/unit-distributed/test_distributed_for_each.cpp create mode 100644 test/unit-distributed/test_network_generation.cpp create mode 100644 test/unit/test_network.cpp create mode 100644 test/unit/test_spatial_tree.cpp diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt index c456e1e355..1b905d2932 100644 --- a/arbor/CMakeLists.txt +++ b/arbor/CMakeLists.txt @@ -44,6 +44,8 @@ set(arbor_sources morph/segment_tree.cpp morph/stitch.cpp merge_events.cpp + network.cpp + network_impl.cpp simulation.cpp partition_load_balance.cpp profile/clock.cpp diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index e258d0226b..e1d3de887c 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -14,6 +14,7 @@ #include "connection.hpp" #include "distributed_context.hpp" #include "execution_context.hpp" +#include "network_impl.hpp" #include "profile/profiler_macro.hpp" #include "threading/threading.hpp" #include "util/partition.hpp" @@ -24,14 +25,12 @@ namespace arb { -communicator::communicator(const recipe& rec, - const domain_decomposition& dom_dec, - execution_context& ctx): num_total_cells_{rec.num_cells()}, - num_local_cells_{dom_dec.num_local_cells()}, - num_local_groups_{dom_dec.num_groups()}, - num_domains_{(cell_size_type) ctx.distributed->size()}, - distributed_{ctx.distributed}, - thread_pool_{ctx.thread_pool} {} +communicator::communicator(const recipe& rec, const domain_decomposition& dom_dec, context ctx): + num_total_cells_{rec.num_cells()}, + num_local_cells_{dom_dec.num_local_cells()}, + num_local_groups_{dom_dec.num_groups()}, + num_domains_{(cell_size_type)ctx->distributed->size()}, + ctx_(std::move(ctx)) {} constexpr inline bool is_external(cell_gid_type c) { @@ -55,7 +54,7 @@ cell_member_type global_cell_of(const cell_member_type& c) { return {c.gid | msb, c.index}; } -void communicator::update_connections(const connectivity& rec, +void communicator::update_connections(const recipe& rec, const domain_decomposition& dom_dec, const label_resolution_map& source_resolution_map, const label_resolution_map& target_resolution_map) { @@ -67,6 +66,9 @@ void communicator::update_connections(const connectivity& rec, index_divisions_.clear(); PL(); + // Construct connections from high-level specification + auto generated_connections = generate_connections(rec, ctx_, dom_dec); + // Make a list of local cells' connections // -> gid_connections // Count the number of local connections (i.e. connections terminating on this domain) @@ -114,9 +116,18 @@ void communicator::update_connections(const connectivity& rec, } part_ext_connections.push_back(gid_ext_connections.size()); } + for (const auto& c: generated_connections) { + auto sgid = c.source.gid; + if (sgid >= num_total_cells_) { + throw arb::bad_connection_source_gid(c.source.gid, sgid, num_total_cells_); + } + const auto src = dom_dec.gid_domain(sgid); + src_domains.push_back(src); + src_counts[src]++; + } util::make_partition(connection_part_, src_counts); - auto n_cons = gid_connections.size(); + auto n_cons = gid_connections.size() + generated_connections.size(); auto n_ext_cons = gid_ext_connections.size(); PL(); @@ -132,6 +143,7 @@ void communicator::update_connections(const connectivity& rec, auto target_resolver = resolver(&target_resolution_map); for (const auto index: util::make_span(num_local_cells_)) { const auto tgt_gid = gids[index]; + const auto iod = dom_dec.index_on_domain(tgt_gid); auto source_resolver = resolver(&source_resolution_map); for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) { const auto& conn = gid_connections[cidx]; @@ -141,7 +153,7 @@ void communicator::update_connections(const connectivity& rec, auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target); auto offset = offsets[*src_domain]++; ++src_domain; - connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index}; + connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, iod}; } for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) { const auto& conn = gid_ext_connections[cidx]; @@ -149,10 +161,15 @@ void communicator::update_connections(const connectivity& rec, auto src_gid = conn.source.rid; if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid); auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target); - ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index}; + ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, iod}; ++ext; } } + for (const auto& c: generated_connections) { + auto offset = offsets[*src_domain]++; + ++src_domain; + connections[offset] = c; + } PL(); PE(init:communicator:update:index); @@ -167,7 +184,7 @@ void communicator::update_connections(const connectivity& rec, // Sort the connections for each domain. // This is num_domains_ independent sorts, so it can be parallelized trivially. const auto& cp = connection_part_; - threading::parallel_for::apply(0, num_domains_, thread_pool_.get(), + threading::parallel_for::apply(0, num_domains_, ctx_->thread_pool.get(), [&](cell_size_type i) { util::sort(util::subrange_view(connections, cp[i], cp[i+1])); }); @@ -193,7 +210,7 @@ time_type communicator::min_delay() { res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(), res, [](auto&& acc, time_type del) { return std::min(acc, del); }); - res = distributed_->min(res); + res = ctx_->distributed->min(res); return res; } @@ -206,7 +223,7 @@ communicator::exchange(std::vector local_spikes) { PE(communication:exchange:gather); // global all-to-all to gather a local copy of the global spike list on each node. - auto global_spikes = distributed_->gather_spikes(local_spikes); + auto global_spikes = ctx_->distributed->gather_spikes(local_spikes); num_spikes_ += global_spikes.size(); PL(); @@ -217,7 +234,7 @@ communicator::exchange(std::vector local_spikes) { local_spikes.end(), [this] (const auto& s) { return !remote_spike_filter_(s); })); } - auto remote_spikes = distributed_->remote_gather_spikes(local_spikes); + auto remote_spikes = ctx_->distributed->remote_gather_spikes(local_spikes); PL(); PE(communication:exchange:gather:remote:post_process); @@ -231,8 +248,8 @@ communicator::exchange(std::vector local_spikes) { } void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_spike_filter_ = p; } -void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); } -void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); } +void communicator::remote_ctrl_send_continue(const epoch& e) { ctx_->distributed->remote_ctrl_send_continue(e); } +void communicator::remote_ctrl_send_done() { ctx_->distributed->remote_ctrl_send_done(); } // Given // * a set of connections and an index into the set diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp index 9b07f56ae6..4a73d10ad5 100644 --- a/arbor/communication/communicator.hpp +++ b/arbor/communication/communicator.hpp @@ -1,11 +1,11 @@ #pragma once #include -#include -#include #include +#include #include +#include #include #include @@ -40,7 +40,7 @@ class ARB_ARBOR_API communicator { explicit communicator(const recipe& rec, const domain_decomposition& dom_dec, - execution_context& ctx); + context ctx); /// The range of event queues that belong to cells in group i. std::pair group_queue_range(cell_size_type i); @@ -78,7 +78,7 @@ class ARB_ARBOR_API communicator { void remote_ctrl_send_continue(const epoch&); void remote_ctrl_send_done(); - void update_connections(const connectivity& rec, + void update_connections(const recipe& rec, const domain_decomposition& dom_dec, const label_resolution_map& source_resolution_map, const label_resolution_map& target_resolution_map); @@ -98,7 +98,7 @@ class ARB_ARBOR_API communicator { for (const auto& con: cons) { idx_on_domain.push_back(con.index_on_domain); srcs.push_back(con.source); - dests.push_back(con.destination); + dests.push_back(con.target); weights.push_back(con.weight); delays.push_back(con.delay); } @@ -136,10 +136,9 @@ class ARB_ARBOR_API communicator { // Currently we have no partitions/indices/acceleration structures connection_list ext_connections_; - distributed_context_handle distributed_; - task_system_handle thread_pool_; std::uint64_t num_spikes_ = 0u; std::uint64_t num_local_events_ = 0u; + context ctx_; }; } // namespace arb diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp new file mode 100644 index 0000000000..2a1399639b --- /dev/null +++ b/arbor/communication/distributed_for_each.hpp @@ -0,0 +1,185 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "distributed_context.hpp" +#include "util/range.hpp" + +namespace arb { + +namespace impl { +template +void for_each_in_tuple(FUNC&& func, std::tuple& t, std::index_sequence) { + (func(Is, std::get(t)), ...); +} + +template +void for_each_in_tuple(FUNC&& func, std::tuple& t) { + for_each_in_tuple(func, t, std::index_sequence_for()); +} + +template +void for_each_in_tuple_pair(FUNC&& func, + std::tuple& t1, + std::tuple& t2, + std::index_sequence) { + (func(Is, std::get(t1), std::get(t2)), ...); +} + +template +void for_each_in_tuple_pair(FUNC&& func, std::tuple& t1, std::tuple& t2) { + for_each_in_tuple_pair(func, t1, t2, std::index_sequence_for()); +} + +} // namespace impl + + +/* + * Collective operation, calling func on args supplied by each rank exactly once. The order of calls + * is unspecified. Requires + * + * - Item = util::range::value_type to be identical across all ranks + * - Item is trivially_copyable + * - Alignment of Item must not exceed std::max_align_t + * - func to be a callable type with signature + * void func(util::range...) + * - func must not modify contents of range + * - All ranks in distributed must call this collectively. + */ +template +void distributed_for_each(FUNC&& func, + const distributed_context& distributed, + const util::range&... args) { + + static_assert(sizeof...(args) > 0); + auto arg_tuple = std::forward_as_tuple(args...); + + struct vec_info { + std::size_t offset; // offset in bytes + std::size_t size; // size in bytes + }; + + std::array info; + std::size_t buffer_size = 0; + + // Compute offsets in bytes for each vector when placed in common buffer + { + std::size_t offset = info.size() * sizeof(vec_info); + impl::for_each_in_tuple( + [&](std::size_t i, auto&& vec) { + using T = typename std::remove_reference_t::value_type; + static_assert(std::is_trivially_copyable_v); + static_assert(alignof(std::max_align_t) >= alignof(T)); + static_assert(alignof(std::max_align_t) % alignof(T) == 0); + + // make sure alignment of offset fulfills requirement + const auto alignment_excess = offset % alignof(T); + offset += alignment_excess > 0 ? alignof(T) - (alignment_excess) : 0; + + const auto size_in_bytes = vec.size() * sizeof(T); + + info[i].size = size_in_bytes; + info[i].offset = offset; + + buffer_size = offset + size_in_bytes; + offset += size_in_bytes; + }, + arg_tuple); + } + + // compute maximum buffer size between ranks, such that we only allocate once + const std::size_t max_buffer_size = distributed.max(buffer_size); + + std::tuple::value_type*>...> + ranges; + + if (max_buffer_size == info.size() * sizeof(vec_info)) { + // if all empty, call function with empty ranges for each step and exit + impl::for_each_in_tuple_pair( + [&](std::size_t i, auto&& vec, auto&& r) { + using T = typename std::remove_reference_t::value_type; + r = util::range(nullptr, nullptr); + }, + arg_tuple, + ranges); + + for (int step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); } + return; + } + + // use malloc for std::max_align_t alignment + auto deleter = [](char* ptr) { std::free(ptr); }; + std::unique_ptr buffer((char*)std::malloc(max_buffer_size), deleter); + std::unique_ptr recv_buffer( + (char*)std::malloc(max_buffer_size), deleter); + + // copy offset and size info to front of buffer + std::memcpy(buffer.get(), info.data(), info.size() * sizeof(vec_info)); + + // copy each vector to each location in buffer + impl::for_each_in_tuple( + [&](std::size_t i, auto&& vec) { + using T = typename std::remove_reference_t::value_type; + std::copy(vec.begin(), vec.end(), (T*)(buffer.get() + info[i].offset)); + }, + arg_tuple); + + + const auto my_rank = distributed.id(); + const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; + const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; + + // exchange buffer in ring pattern and apply function at each step + for (int step = 0; step < distributed.size() - 1; ++step) { + // always expect to recieve the max size but send actual size. MPI_recv only expects a max + // size, not the actual size. + const auto current_info = (const vec_info*)buffer.get(); + + auto request = distributed.send_recv_nonblocking(max_buffer_size, + recv_buffer.get(), + right_rank, + current_info[info.size() - 1].offset + current_info[info.size() - 1].size, + buffer.get(), + left_rank, + 0); + + // update ranges + impl::for_each_in_tuple_pair( + [&](std::size_t i, auto&& vec, auto&& r) { + using T = typename std::remove_reference_t::value_type; + r = util::range((T*)(buffer.get() + current_info[i].offset), + (T*)(buffer.get() + current_info[i].offset + current_info[i].size)); + }, + arg_tuple, + ranges); + + // call provided function with ranges pointing to current buffer + std::apply(func, ranges); + + request.finalize(); + buffer.swap(recv_buffer); + } + + // final step does not require any exchange + const auto current_info = (const vec_info*)buffer.get(); + impl::for_each_in_tuple_pair( + [&](std::size_t i, auto&& vec, auto&& r) { + using T = typename std::remove_reference_t::value_type; + r = util::range((T*)(buffer.get() + current_info[i].offset), + (T*)(buffer.get() + current_info[i].offset + current_info[i].size)); + }, + arg_tuple, + ranges); + + // call provided function with ranges pointing to current buffer + std::apply(func, ranges); +} + +} // namespace arb diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index 4bc17f4d13..46ed0695da 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -89,6 +89,23 @@ struct dry_run_context_impl { return std::vector(num_ranks_, value); } + std::vector gather_all(std::size_t value) const { + return std::vector(num_ranks_, value); + } + + distributed_request send_recv_nonblocking(std::size_t dest_count, + void* dest_data, + int dest, + std::size_t source_count, + const void* source_data, + int source, + int tag) const { + throw arbor_internal_error("send_recv_nonblocking: not implemented for dry run conext."); + + return distributed_request{ + std::make_unique()}; + } + int id() const { return 0; } int size() const { return num_ranks_; } diff --git a/arbor/communication/mpi.hpp b/arbor/communication/mpi.hpp index 8093c22a54..3b760d4f67 100644 --- a/arbor/communication/mpi.hpp +++ b/arbor/communication/mpi.hpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include @@ -315,5 +317,60 @@ T broadcast(int root, MPI_Comm comm) { return value; } +inline std::vector isend(std::size_t num_bytes, + const void* data, + int dest, + int tag, + MPI_Comm comm) { + constexpr std::size_t max_msg_size = static_cast(std::numeric_limits::max()); + + std::vector requests; + + for (std::size_t idx = 0; idx < num_bytes; idx += max_msg_size) { + requests.emplace_back(); + MPI_OR_THROW(MPI_Isend, + reinterpret_cast(const_cast(data)) + idx, + static_cast(std::min(max_msg_size, num_bytes - idx)), + MPI_BYTE, + dest, + tag, + comm, + &(requests.back())); + } + + return requests; +} + +inline std::vector irecv(std::size_t num_bytes, + void* data, + int source, + int tag, + MPI_Comm comm) { + constexpr std::size_t max_msg_size = static_cast(std::numeric_limits::max()); + + std::vector requests; + + for (std::size_t idx = 0; idx < num_bytes; idx += max_msg_size) { + requests.emplace_back(); + MPI_OR_THROW(MPI_Irecv, + reinterpret_cast(data) + idx, + static_cast(std::min(max_msg_size, num_bytes - idx)), + MPI_BYTE, + source, + tag, + comm, + &(requests.back())); + } + + return requests; +} + +inline void wait_all(std::vector requests) { + if(!requests.empty()) { + MPI_OR_THROW( + MPI_Waitall, static_cast(requests.size()), requests.data(), MPI_STATUSES_IGNORE); + } +} + } // namespace mpi } // namespace arb diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index 8109a5a889..85dd6e2673 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -5,6 +5,7 @@ #error "build only if MPI is enabled" #endif +#include #include #include @@ -73,6 +74,63 @@ struct mpi_context_impl { return mpi::gather(value, root, comm_); } + std::vector gather_all(std::size_t value) const { + return mpi::gather_all(value, comm_); + } + + distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const { + + // Return dummy request of nothing to do + if (!recv_count && !send_count) + return distributed_request{ + std::make_unique()}; + if(recv_count && !recv_data) + throw arbor_internal_error( + "send_recv_nonblocking: recv_data is null."); + + if(send_count && !send_data) + throw arbor_internal_error( + "send_recv_nonblocking: send_data is null."); + + if (recv_data == send_data) + throw arbor_internal_error( + "send_recv_nonblocking: recv_data and send_data must not be the same."); + + auto recv_requests = mpi::irecv(recv_count, recv_data, source_id, tag, comm_); + auto send_requests = mpi::isend(send_count, send_data, dest_id, tag, comm_); + + struct mpi_send_recv_request : public distributed_request::distributed_request_interface { + std::vector recv_requests, send_requests; + + mpi_send_recv_request(std::vector recv_requests, + std::vector send_requests): + recv_requests(std::move(recv_requests)), + send_requests(std::move(send_requests)) {} + + void finalize() override { + if (!recv_requests.empty()) { + mpi::wait_all(std::move(recv_requests)); + } + + if (!send_requests.empty()) { + mpi::wait_all(std::move(send_requests)); + } + }; + + ~mpi_send_recv_request() override { this->finalize(); } + }; + + return distributed_request{ + std::unique_ptr( + new mpi_send_recv_request{std::move(recv_requests), std::move(send_requests)})}; + } + std::string name() const { return "MPI"; } int id() const { return rank_; } int size() const { return size_; } @@ -142,6 +200,16 @@ struct remote_context_impl { return mpi_.gather_cell_labels_and_gids(local_labels_and_gids); } + distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const { + return mpi_.send_recv_nonblocking(recv_count, recv_data, source_id, send_count, send_data, dest_id, tag); + } + template std::vector gather(T value, int root) const { return mpi_.gather(value, root); } std::string name() const { return "MPIRemote"; } int id() const { return mpi_.id(); } diff --git a/arbor/connection.hpp b/arbor/connection.hpp index e83fa28722..276ed217aa 100644 --- a/arbor/connection.hpp +++ b/arbor/connection.hpp @@ -10,7 +10,7 @@ namespace arb { struct connection { cell_member_type source = {0, 0}; - cell_lid_type destination = 0; + cell_lid_type target = 0; float weight = 0.0f; float delay = 0.0f; cell_size_type index_on_domain = cell_gid_type(-1); @@ -18,7 +18,7 @@ struct connection { inline spike_event make_event(const connection& c, const spike& s) { - return {c.destination, s.time + c.delay, c.weight}; + return {c.target, s.time + c.delay, c.weight}; } // connections are sorted by source id @@ -30,7 +30,7 @@ static inline bool operator<(cell_member_type lhs, const connection& rhs) { ret } // namespace arb static inline std::ostream& operator<<(std::ostream& o, arb::connection const& con) { - return o << "con [" << con.source << " -> " << con.destination + return o << "con [" << con.source << " -> " << con.target << " : weight " << con.weight << ", delay " << con.delay << ", index " << con.index_on_domain << "]"; diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index 76e8f4c0a9..9a31613f92 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -34,6 +35,34 @@ namespace arb { #define ARB_COLLECTIVE_TYPES_ float, double, int, unsigned, long, unsigned long, long long, unsigned long long + +// A helper struct, representing a request for data exchange. +// After calling finalize() or destruction, the data exchange is guaranteed to be finished. +struct distributed_request { + struct distributed_request_interface { + virtual void finalize() {}; + + virtual ~distributed_request_interface() = default; + }; + + inline void finalize() { + if (impl) { + impl->finalize(); + impl.reset(); + } + } + + ~distributed_request() { + try { + finalize(); + } + catch (...) { + } + } + + std::unique_ptr impl; +}; + // Defines the concept/interface for a distributed communication context. // // Uses value-semantic type erasure to define the interface, so that @@ -84,6 +113,27 @@ class distributed_context { return impl_->gather(value, root); } + template + distributed_request send_recv_nonblocking(std::size_t recv_count, + T* recv_data, + int source_id, + std::size_t send_count, + const T* send_data, + int dest_id, + int tag) const { + static_assert(std::is_trivially_copyable::value, + "send_recv_nonblocking: Type T must be trivially copyable for memcpy or MPI send / " + "recv using MPI_BYTE."); + + return impl_->send_recv_nonblocking(recv_count * sizeof(T), + recv_data, + source_id, + send_count * sizeof(T), + send_data, + dest_id, + tag); + } + int id() const { return impl_->id(); } @@ -119,6 +169,13 @@ class distributed_context { gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const = 0; virtual std::vector gather(std::string value, int root) const = 0; + virtual distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const = 0; virtual int id() const = 0; virtual int size() const = 0; virtual void barrier() const = 0; @@ -160,6 +217,16 @@ class distributed_context { gather(std::string value, int root) const override { return wrapped.gather(value, root); } + distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const override { + return wrapped.send_recv_nonblocking( + recv_count, recv_data, source_id, send_count, send_data, dest_id, tag); + } int id() const override { return wrapped.id(); } @@ -220,6 +287,25 @@ struct local_context { return {std::move(value)}; } + distributed_request send_recv_nonblocking(std::size_t dest_count, + void* dest_data, + int dest, + std::size_t source_count, + const void* source_data, + int source, + int tag) const { + if (source != 0 || dest != 0) + throw arbor_internal_error( + "send_recv_nonblocking: source and destination id must be 0 for local context."); + if (dest_count != source_count) + throw arbor_internal_error( + "send_recv_nonblocking: dest_count not equal to source_count."); + std::memcpy(dest_data, source_data, source_count); + + return distributed_request{ + std::make_unique()}; + } + int id() const { return 0; } int size() const { return 1; } diff --git a/arbor/domain_decomposition.cpp b/arbor/domain_decomposition.cpp index f9c14a03d9..a91608a25a 100644 --- a/arbor/domain_decomposition.cpp +++ b/arbor/domain_decomposition.cpp @@ -1,11 +1,13 @@ #include #include #include +#include -#include +#include +#include #include +#include #include -#include #include "execution_context.hpp" #include "util/partition.hpp" @@ -20,13 +22,18 @@ domain_decomposition::domain_decomposition(const recipe& rec, partition_gid_domain(const gathered_vector& divs, unsigned domains) { auto rank_part = util::partition_view(divs.partition()); for (auto rank: count_along(rank_part)) { + cell_size_type index_on_domain = 0; for (auto gid: util::subrange_view(divs.values(), rank_part[rank])) { - gid_map[gid] = rank; + gid_map[gid] = {rank, index_on_domain}; + ++index_on_domain; } } } - int operator()(cell_gid_type gid) const { return gid_map.at(gid); } - std::unordered_map gid_map; + std::pair operator()(cell_gid_type gid) const { + return gid_map.at(gid); + } + // Maps gid to domain index and cell index on domain + std::unordered_map> gid_map; }; const auto* dist = ctx->distributed.get(); @@ -69,7 +76,11 @@ domain_decomposition::domain_decomposition(const recipe& rec, } int domain_decomposition::gid_domain(cell_gid_type gid) const { - return gid_domain_(gid); + return gid_domain_(gid).first; +} + +cell_size_type domain_decomposition::index_on_domain(cell_gid_type gid) const { + return gid_domain_(gid).second; } int domain_decomposition::num_domains() const { diff --git a/arbor/include/arbor/common_types.hpp b/arbor/include/arbor/common_types.hpp index e2d1eccdc7..32737f883b 100644 --- a/arbor/include/arbor/common_types.hpp +++ b/arbor/include/arbor/common_types.hpp @@ -70,6 +70,19 @@ struct lid_range { begin(b), end(e) {} }; +// Global range of indices with given step size. + +struct gid_range { + cell_gid_type begin = 0; + cell_gid_type end = 0; + cell_gid_type step = 1; + gid_range() = default; + gid_range(cell_gid_type b, cell_gid_type e): + begin(b), end(e), step(1) {} + gid_range(cell_gid_type b, cell_gid_type e, cell_gid_type s): + begin(b), end(e), step(s) {} +}; + // Policy for selecting a cell_lid_type from a range of possible values. enum class lid_selection_policy { diff --git a/arbor/include/arbor/domain_decomposition.hpp b/arbor/include/arbor/domain_decomposition.hpp index 2706cd2935..54f61ed7ce 100644 --- a/arbor/include/arbor/domain_decomposition.hpp +++ b/arbor/include/arbor/domain_decomposition.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -45,6 +46,7 @@ class ARB_ARBOR_API domain_decomposition { domain_decomposition& operator=(const domain_decomposition&) = default; int gid_domain(cell_gid_type gid) const; + cell_size_type index_on_domain(cell_gid_type gid) const; int num_domains() const; int domain_id() const; cell_size_type num_local_cells() const; @@ -54,10 +56,10 @@ class ARB_ARBOR_API domain_decomposition { const group_description& group(unsigned) const; private: - /// Return the domain id of cell with gid. + /// Return the domain id and index on domain of cell with gid. /// Supplied by the load balancing algorithm that generates the domain /// decomposition. - std::function gid_domain_; + std::function(cell_gid_type)> gid_domain_; /// Number of distributed domains int num_domains_; diff --git a/arbor/include/arbor/math.hpp b/arbor/include/arbor/math.hpp index d9483f3501..24ba307889 100644 --- a/arbor/include/arbor/math.hpp +++ b/arbor/include/arbor/math.hpp @@ -32,6 +32,19 @@ T constexpr area_circle(T r) { return pi * square(r); } +template >> +T constexpr pow(T base, U exp) { + if (exp == 0) return 1; + + const U exp_half = exp / 2; + if (2 * exp_half == exp) { + const auto r = ::arb::math::pow(base, exp_half); + return r * r; + } + + return base * ::arb::math::pow(base, exp - 1); +} + // Surface area of conic frustrum excluding the discs at each end, // with length L, end radii r1, r2. template diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp new file mode 100644 index 0000000000..262f318bdf --- /dev/null +++ b/arbor/include/arbor/network.hpp @@ -0,0 +1,331 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace arb { + +using network_hash_type = std::uint64_t; + +struct ARB_SYMBOL_VISIBLE network_site_info { + network_site_info(cell_gid_type gid, + cell_kind kind, + hash_type label, + mlocation location, + mpoint global_location): + gid(gid), + kind(kind), + label(label), + location(location), + global_location(global_location) {} + + cell_gid_type gid; + cell_kind kind; + hash_type label; + mlocation location; + mpoint global_location; + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_site_info& s); +}; + +ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_site_info, + (a.gid, a.kind, a.label, a.location, a.global_location), + (b.gid, a.kind, b.label, b.location, b.global_location)) + +struct ARB_SYMBOL_VISIBLE network_connection_info { + network_site_info source, target; + double weight, delay; + + network_connection_info(network_site_info source, + network_site_info target, + double weight, + double delay): + source(source), + target(target), + weight(weight), + delay(delay) {} + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, + const network_connection_info& s); +}; + +ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_connection_info, + (a.source, a.target, a.weight, a.delay), + (b.source, b.target, b.weight, b.delay)) + +struct network_selection_impl; + +struct network_value_impl; + +class ARB_SYMBOL_VISIBLE network_label_dict; + +class ARB_SYMBOL_VISIBLE network_selection; + +class ARB_SYMBOL_VISIBLE network_value { +public: + using custom_func_type = + std::function; + + network_value() { *this = network_value::scalar(0.0); } + + // Scalar value with conversion from double + network_value(double value) { *this = network_value::scalar(value); } + + // Scalar value. Will always return the same value given at construction. + static network_value scalar(double value); + + // A named value inside a network label dictionary + static network_value named(std::string name); + + // Distamce netweem source and target site + static network_value distance(double scale = 1.0); + + // Uniform random value in (range[0], range[1]]. + static network_value uniform_distribution(unsigned seed, const std::array& range); + + // Radom value from a normal distribution with given mean and standard deviation. + static network_value normal_distribution(unsigned seed, double mean, double std_deviation); + + // Radom value from a truncated normal distribution with given mean and standard deviation (of a + // non-truncated normal distribution), where the value is always in (range[0], range[1]]. + // Note: Values are generated by reject-accept method from a normal + // distribution. Low acceptance rate can leed to poor performance, for example with very small + // ranges or a mean far outside the range. + static network_value truncated_normal_distribution(unsigned seed, + double mean, + double std_deviation, + const std::array& range); + + // Custom value using the provided function "func". Repeated calls with the same arguments + // to "func" must yield the same result. + static network_value custom(custom_func_type func); + + static network_value add(network_value left, network_value right); + + static network_value sub(network_value left, network_value right); + + static network_value mul(network_value left, network_value right); + + static network_value div(network_value left, network_value right); + + static network_value exp(network_value v); + + static network_value log(network_value v); + + static network_value min(network_value left, network_value right); + + static network_value max(network_value left, network_value right); + + // if contained in selection, the true_value is used and the false_value otherwise. + static network_value if_else(network_selection cond, + network_value true_value, + network_value false_value); + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_value& v); + +private: + network_value(std::shared_ptr impl); + + friend std::shared_ptr thingify(network_value v, + const network_label_dict& dict); + + std::shared_ptr impl_; +}; + +ARB_ARBOR_API inline network_value operator+(network_value a, network_value b) { + return network_value::add(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator-(network_value a, network_value b) { + return network_value::sub(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator*(network_value a, network_value b) { + return network_value::mul(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator/(network_value a, network_value b) { + return network_value::div(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator+(network_value a) { return a; } + +ARB_ARBOR_API inline network_value operator-(network_value a) { + return network_value::mul(-1.0, std::move(a)); +} + +class ARB_SYMBOL_VISIBLE network_selection { +public: + using custom_func_type = + std::function; + + network_selection() { *this = network_selection::none(); } + + // Select all + static network_selection all(); + + // Select none + static network_selection none(); + + // Named selection in the network label dictionary + static network_selection named(std::string name); + + // Only select connections between different cells + static network_selection inter_cell(); + + // Select connections with the given source cell kind + static network_selection source_cell_kind(cell_kind kind); + + // Select connections with the given target cell kind + static network_selection target_cell_kind(cell_kind kind); + + // Select connections with the given source label + static network_selection source_label(std::vector labels); + + // Select connections with the given target label + static network_selection target_label(std::vector labels); + + // Select connections with source cells matching the indices in the list + static network_selection source_cell(std::vector gids); + + // Select connections with source cells matching the indices in the range + static network_selection source_cell(gid_range range); + + // Select connections with target cells matching the indices in the list + static network_selection target_cell(std::vector gids); + + // Select connections with target cells matching the indices in the range + static network_selection target_cell(gid_range range); + + // Select connections that form a chain, such that source cell "i" is connected to the target + // cell "i+1" + static network_selection chain(std::vector gids); + + // Select connections that form a chain, such that source cell "i" is connected to the target + // cell "i+1" + static network_selection chain(gid_range range); + + // Select connections that form a reversed chain, such that source cell "i+1" is connected to + // the target cell "i" + static network_selection chain_reverse(gid_range range); + + // Select connections, that are selected by both "left" and "right" + static network_selection intersect(network_selection left, network_selection right); + + // Select connections, that are selected by either or both "left" and "right" + static network_selection join(network_selection left, network_selection right); + + // Select connections, that are selected by "left", unless selected by "right" + static network_selection difference(network_selection left, network_selection right); + + // Select connections, that are selected by "left" or "right", but not both + static network_selection symmetric_difference(network_selection left, network_selection right); + + // Invert the selection + static network_selection complement(network_selection s); + + // Random selection using the bernoulli random distribution with probability "p" between 0.0 + // and 1.0 + static network_selection random(unsigned seed, network_value p); + + // Custom selection using the provided function "func". Repeated calls with the same arguments + // to "func" must yield the same result. For gap junction selection, + // "func" must be symmetric (func(a,b) = func(b,a)). + static network_selection custom(custom_func_type func); + + // only select within given distance. This may enable more efficient sampling through an + // internal spatial data structure. + static network_selection distance_lt(double d); + + // only select if distance greater then given distance. This may enable more efficient sampling + // through an internal spatial data structure. + static network_selection distance_gt(double d); + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_selection& s); + +private: + network_selection(std::shared_ptr impl); + + friend std::shared_ptr thingify(network_selection s, + const network_label_dict& dict); + + friend class network_value; + + std::shared_ptr impl_; +}; + +class ARB_SYMBOL_VISIBLE network_label_dict { +public: + using ns_map = std::unordered_map; + using nv_map = std::unordered_map; + + // Store a network selection under the given name + network_label_dict& set(const std::string& name, network_selection s); + + // Store a network value under the given name + network_label_dict& set(const std::string& name, network_value v); + + // Returns the stored network selection of the given name if it exists. None otherwise. + std::optional selection(const std::string& name) const; + + // Returns the stored network value of the given name if it exists. None otherwise. + std::optional value(const std::string& name) const; + + // All stored network selections + inline const ns_map& selections() const { return selections_; } + + // All stored network value + inline const nv_map& values() const { return values_; } + +private: + ns_map selections_; + nv_map values_; +}; + +// A complete network description required for processing +struct ARB_SYMBOL_VISIBLE network_description { + network_selection selection; + network_value weight; + network_value delay; + network_label_dict dict; +}; + +// Join two network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); + +// Join three or more network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); +template +network_selection join(network_selection l, network_selection r, Args... args) { + return join(join(std::move(l), std::move(r)), std::move(args)...); +} + +// Intersect two network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); +ARB_ARBOR_API network_selection intersect(network_selection left, network_selection right); + +// Intersect three or more network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); +template +network_selection intersect(network_selection l, network_selection r, Args... args) { + return intersect(intersect(std::move(l), std::move(r)), std::move(args)...); +} + +} // namespace arb diff --git a/arbor/include/arbor/network_generation.hpp b/arbor/include/arbor/network_generation.hpp new file mode 100644 index 0000000000..262cc426b6 --- /dev/null +++ b/arbor/include/arbor/network_generation.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include + +namespace arb { + +// Generate and return list of connections from the network description of the recipe. +// Does not include connections from the "connections_on" recipe function. +// Only returns connections with local cell targets as described in the domain decomposition. +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec); + +// Generate and return list of ALL connections from the network description of the recipe. +// Does not include connections from the "connections_on" recipe function. +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec); + +} // namespace arb diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index dbcba5fcde..a8bd8eb17e 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -1,12 +1,16 @@ #pragma once #include +#include #include #include -#include #include +#include #include +#include +#include +#include #include namespace arb { @@ -84,6 +88,10 @@ struct ARB_ARBOR_API has_synapses { virtual std::vector connections_on(cell_gid_type) const { return {}; } + // Optional network descriptions for generating cell connections + virtual std::optional network_description() const { + return std::nullopt; + }; virtual ~has_synapses() {} }; @@ -125,6 +133,8 @@ struct ARB_ARBOR_API recipe: public has_gap_junctions, has_probes, connectivity virtual cell_kind get_cell_kind(cell_gid_type) const = 0; // Global property type will be specific to given cell kind. virtual std::any get_global_properties(cell_kind) const { return std::any{}; }; + // Global cell isometry describing rotation and translation of the cell + virtual isometry get_cell_isometry(cell_gid_type gid) const { return isometry(); }; virtual ~recipe() {} }; diff --git a/arbor/include/arbor/simulation.hpp b/arbor/include/arbor/simulation.hpp index 86e30e104f..a2c5d69ecb 100644 --- a/arbor/include/arbor/simulation.hpp +++ b/arbor/include/arbor/simulation.hpp @@ -44,7 +44,7 @@ class ARB_ARBOR_API simulation { static simulation_builder create(recipe const &); - void update(const connectivity& rec); + void update(const recipe& rec); void reset(); diff --git a/arbor/network.cpp b/arbor/network.cpp new file mode 100644 index 0000000000..ba3ece380c --- /dev/null +++ b/arbor/network.cpp @@ -0,0 +1,1453 @@ +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "backends/rand_impl.hpp" +#include "network_impl.hpp" + +namespace arb { + +namespace { + +// Partial seed to use for network_value and network_selection generation. +// Different seed for each type to avoid unintentional correlation. +enum class network_seed : unsigned { + selection_random = 2058443, + value_uniform = 48202, + value_normal = 8405, + value_truncated_normal = 380237, +}; + +std::uint64_t location_hash(const mlocation& loc) { + const double l = static_cast(loc.branch) + loc.pos; + return *reinterpret_cast(&l); +} + +double uniform_rand(std::array seed, + const network_site_info& source, + const network_site_info& target) { + const cbprng::array_type seed_input = {{seed[0], seed[1], seed[2], seed[3]}}; + + const cbprng::array_type key = { + {source.gid, location_hash(source.location), target.gid, location_hash(target.location)}}; + cbprng::generator gen; + return r123::u01(gen(seed_input, key)[0]); +} + +double normal_rand(std::array seed, + const network_site_info& source, + const network_site_info& target) { + + using rand_type = r123::Threefry4x64; + const rand_type::ctr_type seed_input = {{seed[0], seed[1], seed[2], seed[3]}}; + + const rand_type::key_type key = { + {source.gid, location_hash(source.location), target.gid, location_hash(target.location)}}; + rand_type gen; + const auto rand_num = gen(seed_input, key); + + return r123::boxmuller(rand_num[0], rand_num[1]).x; +} + +struct network_selection_all_impl: public network_selection_impl { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return true; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { os << "(all)"; } +}; + +struct network_selection_none_impl: public network_selection_impl { + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return false; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return false; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return false; + } + + void print(std::ostream& os) const override { os << "(none)"; } +}; + +struct network_selection_source_cell_kind_impl: public network_selection_impl { + cell_kind select_kind; + + explicit network_selection_source_cell_kind_impl(cell_kind k): select_kind(k) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return source.kind == select_kind; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return kind == select_kind; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { + os << "(source-cell-kind ("; + switch (select_kind) { + case arb::cell_kind::spike_source: os << "spike-source"; break; + case arb::cell_kind::cable: os << "cable"; break; + case arb::cell_kind::lif: os << "lif"; break; + case arb::cell_kind::benchmark: os << "benchmark"; break; + } + os << "-cell))"; + } +}; + +struct network_selection_target_cell_kind_impl: public network_selection_impl { + cell_kind select_kind; + + explicit network_selection_target_cell_kind_impl(cell_kind k): select_kind(k) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return target.kind == select_kind; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return kind == select_kind; + } + + void print(std::ostream& os) const override { + os << "(target-cell-kind ("; + switch (select_kind) { + case arb::cell_kind::spike_source: os << "spike-source"; break; + case arb::cell_kind::cable: os << "cable"; break; + case arb::cell_kind::lif: os << "lif"; break; + case arb::cell_kind::benchmark: os << "benchmark"; break; + } + os << "-cell))"; + } +}; + +struct network_selection_source_label_impl: public network_selection_impl { + std::vector labels; + std::vector sorted_hashes; + + explicit network_selection_source_label_impl(std::vector labels_): + labels(std::move(labels_)) { + sorted_hashes.reserve(labels.size()); + for (const auto& l: labels) sorted_hashes.emplace_back(hash_value(l)); + + std::sort(sorted_hashes.begin(), sorted_hashes.end()); + } + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), source.label); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), label); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { + os << "(source-label"; + for (const auto& l: labels) { os << " \"" << l << "\""; } + os << ")"; + } +}; + +struct network_selection_target_label_impl: public network_selection_impl { + std::vector labels; + std::vector sorted_hashes; + + explicit network_selection_target_label_impl(std::vector labels_): + labels(std::move(labels_)) { + sorted_hashes.reserve(labels.size()); + for (const auto& l: labels) sorted_hashes.emplace_back(hash_value(l)); + + std::sort(sorted_hashes.begin(), sorted_hashes.end()); + } + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), target.label); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), label); + } + + void print(std::ostream& os) const override { + os << "(target-label"; + for (const auto& l: labels) { os << " \"" << l << "\""; } + os << ")"; + } +}; + +struct network_selection_source_cell_impl: public network_selection_impl { + std::vector sorted_gids; + + explicit network_selection_source_cell_impl(std::vector gids): + sorted_gids(std::move(gids)) { + std::sort(sorted_gids.begin(), sorted_gids.end()); + } + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), source.gid); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { + os << "(source-cell"; + for (const auto& g: sorted_gids) { os << " " << g; } + os << ")"; + } +}; + +struct network_selection_source_cell_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end, step; + + network_selection_source_cell_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return source.gid >= gid_begin && source.gid < gid_end && + !((source.gid - gid_begin) % step); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { + os << "(source-cell (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } +}; + +struct network_selection_target_cell_impl: public network_selection_impl { + std::vector sorted_gids; + + network_selection_target_cell_impl(std::vector gids): + sorted_gids(std::move(gids)) { + std::sort(sorted_gids.begin(), sorted_gids.end()); + } + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), target.gid); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + } + + void print(std::ostream& os) const override { + os << "(target-cell"; + for (const auto& g: sorted_gids) { os << " " << g; } + os << ")"; + } +}; + +struct network_selection_target_cell_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end, step; + + network_selection_target_cell_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return target.gid >= gid_begin && target.gid < gid_end && + !((target.gid - gid_begin) % step); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); + } + + void print(std::ostream& os) const override { + os << "(target-cell (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } +}; + +struct network_selection_chain_impl: public network_selection_impl { + std::vector gids; // preserved order of ring + std::vector sorted_gids; + network_selection_chain_impl(std::vector gids): gids(std::move(gids)) { + sorted_gids = this->gids; // copy + std::sort(sorted_gids.begin(), sorted_gids.end()); + } + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + if (gids.empty()) return false; + + // gids size always > 0 frome here on + + // First check if both are part of ring + if (!std::binary_search(sorted_gids.begin(), sorted_gids.end(), source.gid) || + !std::binary_search(sorted_gids.begin(), sorted_gids.end(), target.gid)) + return false; + + for (std::size_t i = 0; i < gids.size() - 1; ++i) { + // return true if neighbors in gids list + if ((source.gid == gids[i] && target.gid == gids[i + 1])) return true; + } + + return false; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return !sorted_gids.empty() && + std::binary_search(sorted_gids.begin(), sorted_gids.end() - 1, gid); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return !sorted_gids.empty() && + std::binary_search(sorted_gids.begin() + 1, sorted_gids.end(), gid); + } + + void print(std::ostream& os) const override { + os << "(chain"; + for (const auto& g: gids) { os << " " << g; } + os << ")"; + } +}; + +struct network_selection_chain_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end, step; + + network_selection_chain_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + if (source.gid < gid_begin || source.gid >= gid_end || target.gid < gid_begin || + target.gid >= gid_end) + return false; + + return source.gid + step == target.gid && !((source.gid - gid_begin) % step); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + // Return false if outside range or if equal to last element, which cannot be a source + if (gid < gid_begin || gid >= gid_end - 1) return false; + return !((gid - gid_begin) % step); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + // Return false if outside range or if equal to first element, which cannot be a target + if (gid <= gid_begin || gid >= gid_end) return false; + return !((gid - gid_begin) % step); + } + + void print(std::ostream& os) const override { + os << "(chain (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } +}; + +struct network_selection_reverse_chain_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end, step; + + network_selection_reverse_chain_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + if (source.gid < gid_begin || source.gid >= gid_end || target.gid < gid_begin || + target.gid >= gid_end) + return false; + + return target.gid + step == source.gid && !((source.gid - gid_begin) % step); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + // Return false if outside range or if equal to first element, which cannot be a source + if (gid <= gid_begin || gid >= gid_end) return false; + return !((gid - gid_begin) % step); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + // Return false if outside range or if equal to last element, which cannot be a target + if (gid < gid_begin || gid >= gid_end - 1) return false; + return !((gid - gid_begin) % step); + } + + void print(std::ostream& os) const override { + os << "(chain-reverse (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } +}; + +struct network_selection_complement_impl: public network_selection_impl { + std::shared_ptr selection; + + explicit network_selection_complement_impl(std::shared_ptr s): + selection(std::move(s)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return !selection->select_connection(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; // cannot exclude any because source selection cannot be complemented without + // knowing selection criteria. + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; // cannot exclude any because target selection cannot be complemented + // without knowing selection criteria. + } + + void initialize(const network_label_dict& dict) override { selection->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(complement "; + selection->print(os); + os << ")"; + } +}; + +struct network_selection_named_impl: public network_selection_impl { + using impl_pointer_type = std::shared_ptr; + + impl_pointer_type selection; + std::string selection_name; + + explicit network_selection_named_impl(std::string name): selection_name(std::move(name)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + if (!selection) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return selection->select_connection(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + if (!selection) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return selection->select_source(kind, gid, label); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + if (!selection) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return selection->select_target(kind, gid, label); + } + + void initialize(const network_label_dict& dict) override { + auto s = dict.selection(selection_name); + if (!s.has_value()) + throw arbor_exception( + std::string("Network selection with label \"") + selection_name + "\" not found."); + selection = thingify(s.value(), dict); + }; + + void print(std::ostream& os) const override { + os << "(network-selection \"" << selection_name << "\")"; + } +}; + +struct network_selection_inter_cell_impl: public network_selection_impl { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return source.gid != target.gid; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { os << "(inter-cell)"; } +}; + +struct network_selection_custom_impl: public network_selection_impl { + network_selection::custom_func_type func; + + explicit network_selection_custom_impl(network_selection::custom_func_type f): + func(std::move(f)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return func(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { os << "(custom-network-selection)"; } +}; + +struct network_selection_distance_lt_impl: public network_selection_impl { + double d; + + explicit network_selection_distance_lt_impl(double d): d(d) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return distance(source.global_location, target.global_location) < d; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + std::optional max_distance() const override { return d; } + + void print(std::ostream& os) const override { os << "(distance-lt " << d << ")"; } +}; + +struct network_selection_distance_gt_impl: public network_selection_impl { + double d; + + explicit network_selection_distance_gt_impl(double d): d(d) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return distance(source.global_location, target.global_location) > d; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void print(std::ostream& os) const override { os << "(distance-gt " << d << ")"; } +}; + +struct network_selection_random_impl: public network_selection_impl { + unsigned seed; + + network_value p_value; + std::shared_ptr probability; // may be null if unitialize(...) not called + + network_selection_random_impl(unsigned seed, network_value p): + seed(seed), + p_value(std::move(p)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + if (!probability) + throw arbor_internal_error("Trying to use unitialized named network selection."); + const auto r = uniform_rand( + {unsigned(network_seed::selection_random), seed, seed + 1, seed + 2}, source, target); + const auto p = (probability->get(source, target)); + return r < p; + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return true; + } + + void initialize(const network_label_dict& dict) override { + probability = thingify(p_value, dict); + }; + + void print(std::ostream& os) const override { + os << "(random " << seed << " "; + os << p_value; + os << ")"; + } +}; + +struct network_selection_intersect_impl: public network_selection_impl { + std::shared_ptr left, right; + + network_selection_intersect_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return left->select_connection(source, target) && right->select_connection(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_source(kind, gid, label) && right->select_source(kind, gid, label); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_target(kind, gid, label) && right->select_target(kind, gid, label); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + const auto d_right = right->max_distance(); + + if (d_left && d_right) return std::min(d_left.value(), d_right.value()); + if (d_left) return d_left.value(); + if (d_right) return d_right.value(); + + return std::nullopt; + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(intersect "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_selection_join_impl: public network_selection_impl { + std::shared_ptr left, right; + + network_selection_join_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return left->select_connection(source, target) || right->select_connection(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + const auto d_right = right->max_distance(); + + if (d_left && d_right) return std::max(d_left.value(), d_right.value()); + + return std::nullopt; + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(join "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_selection_symmetric_difference_impl: public network_selection_impl { + std::shared_ptr left, right; + + network_selection_symmetric_difference_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return left->select_connection(source, target) ^ right->select_connection(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + const auto d_right = right->max_distance(); + + if (d_left && d_right) return std::max(d_left.value(), d_right.value()); + + return std::nullopt; + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(symmetric-difference "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_selection_difference_impl: public network_selection_impl { + std::shared_ptr left, right; + + network_selection_difference_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return left->select_connection(source, target) && + !(right->select_connection(source, target)); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_source(kind, gid, label); + } + + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return left->select_target(kind, gid, label); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + + if (d_left) return d_left.value(); + + return std::nullopt; + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(difference "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_scalar_impl: public network_value_impl { + double value; + + network_value_scalar_impl(double v): value(v) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return value; + } + + void print(std::ostream& os) const override { os << "(scalar " << value << ")"; } +}; + +struct network_value_distance_impl: public network_value_impl { + double scale; + + network_value_distance_impl(double s): scale(s) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return scale * distance(source.global_location, target.global_location); + } + + void print(std::ostream& os) const override { os << "(distance " << scale << ")"; } +}; + +struct network_value_uniform_distribution_impl: public network_value_impl { + unsigned seed = 0; + std::array range; + + network_value_uniform_distribution_impl(unsigned rand_seed, const std::array& r): + seed(rand_seed), + range(r) { + if (range[0] >= range[1]) + throw std::invalid_argument("Uniform distribution: invalid range"); + } + + double get(const network_site_info& source, const network_site_info& target) const override { + if (range[0] > range[1]) return range[1]; + + // random number between 0 and 1 + const auto rand_num = uniform_rand( + {unsigned(network_seed::value_uniform), seed, seed + 1, seed + 2}, source, target); + + return (range[1] - range[0]) * rand_num + range[0]; + } + + void print(std::ostream& os) const override { + os << "(uniform-distribution " << seed << " " << range[0] << " " << range[1] << ")"; + } +}; + +struct network_value_normal_distribution_impl: public network_value_impl { + unsigned seed = 0; + double mean = 0.0; + double std_deviation = 1.0; + + network_value_normal_distribution_impl(unsigned rand_seed, double mean_, double std_deviation_): + seed(rand_seed), + mean(mean_), + std_deviation(std_deviation_) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return mean + + std_deviation * + normal_rand({unsigned(network_seed::value_normal), seed, seed + 1, seed + 2}, + source, + target); + } + + void print(std::ostream& os) const override { + os << "(normal-distribution " << seed << " " << mean << " " << std_deviation << ")"; + } +}; + +struct network_value_truncated_normal_distribution_impl: public network_value_impl { + unsigned seed = 0; + double mean = 0.0; + double std_deviation = 1.0; + std::array range; + + network_value_truncated_normal_distribution_impl(unsigned rand_seed, + double mean_, + double std_deviation_, + const std::array& range_): + seed(rand_seed), + mean(mean_), + std_deviation(std_deviation_), + range(range_) { + if (range[0] >= range[1]) + throw std::invalid_argument("Truncated normal distribution: invalid range"); + } + + double get(const network_site_info& source, const network_site_info& target) const override { + + double value = 0.0; + + auto dynamic_seed = seed; + do { + value = + mean + std_deviation * normal_rand({unsigned(network_seed::value_truncated_normal), + dynamic_seed, + dynamic_seed + 1, + dynamic_seed + 2}, + source, + target); + ++dynamic_seed; + } while (!(value > range[0] && value <= range[1])); + + return value; + } + + void print(std::ostream& os) const override { + os << "(truncated-normal-distribution " << seed << " " << mean << " " << std_deviation + << " " << range[0] << " " << range[1] << ")"; + } +}; + +struct network_value_custom_impl: public network_value_impl { + network_value::custom_func_type func; + + network_value_custom_impl(network_value::custom_func_type f): func(std::move(f)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return func(source, target); + } + + void print(std::ostream& os) const override { os << "(custom-network-value)"; } +}; + +struct network_value_named_impl: public network_value_impl { + using impl_pointer_type = std::shared_ptr; + + impl_pointer_type value; + std::string value_name; + + explicit network_value_named_impl(std::string name): value_name(std::move(name)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + if (!value) throw arbor_internal_error("Trying to use unitialized named network value."); + return value->get(source, target); + } + + void initialize(const network_label_dict& dict) override { + auto v = dict.value(value_name); + if (!v.has_value()) + throw arbor_exception( + std::string("Network value with label \"") + value_name + "\" not found."); + value = thingify(v.value(), dict); + }; + + void print(std::ostream& os) const override { + os << "(network-value \"" << value_name << "\")"; + } +}; + +struct network_value_add_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_add_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return left->get(source, target) + right->get(source, target); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(add "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_mul_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_mul_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return left->get(source, target) * right->get(source, target); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(mul "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_sub_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_sub_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return left->get(source, target) - right->get(source, target); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(sub "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_div_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_div_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + const auto v_right = right->get(source, target); + if (!v_right) throw arbor_exception("network_value: division by 0."); + return left->get(source, target) / right->get(source, target); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(div "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_max_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_max_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return std::max(left->get(source, target), right->get(source, target)); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(max "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_min_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_min_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return std::min(left->get(source, target), right->get(source, target)); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(min "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_exp_impl: public network_value_impl { + std::shared_ptr value; + + network_value_exp_impl(std::shared_ptr v): value(std::move(v)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + return std::exp(value->get(source, target)); + } + + void initialize(const network_label_dict& dict) override { value->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(exp "; + value->print(os); + os << ")"; + } +}; + +struct network_value_log_impl: public network_value_impl { + std::shared_ptr value; + + network_value_log_impl(std::shared_ptr v): value(std::move(v)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + const auto v = value->get(source, target); + if (v <= 0.0) throw arbor_exception("network_value: log of value <= 0.0."); + return std::log(value->get(source, target)); + } + + void initialize(const network_label_dict& dict) override { value->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(log "; + value->print(os); + os << ")"; + } +}; + +struct network_value_if_else_impl: public network_value_impl { + std::shared_ptr cond; + std::shared_ptr true_value; + std::shared_ptr false_value; + + network_value_if_else_impl(std::shared_ptr cond, + std::shared_ptr true_value, + std::shared_ptr false_value): + cond(std::move(cond)), + true_value(std::move(true_value)), + false_value(std::move(false_value)) {} + + double get(const network_site_info& source, const network_site_info& target) const override { + if (cond->select_connection(source, target)) return true_value->get(source, target); + return false_value->get(source, target); + } + + void initialize(const network_label_dict& dict) override { + cond->initialize(dict); + true_value->initialize(dict); + false_value->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(if-else "; + cond->print(os); + os << " "; + true_value->print(os); + os << " "; + false_value->print(os); + os << ")"; + } +}; + +} // namespace + +network_selection::network_selection(std::shared_ptr impl): + impl_(std::move(impl)) {} + +network_selection network_selection::intersect(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); +} + +network_selection network_selection::join(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); +} + +network_selection network_selection::symmetric_difference(network_selection left, + network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); +} + +network_selection network_selection::difference(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); +} + +network_selection network_selection::all() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::none() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::named(std::string name) { + return network_selection(std::make_shared(std::move(name))); +} + +network_selection network_selection::source_cell_kind(cell_kind kind) { + return network_selection(std::make_shared(kind)); +} + +network_selection network_selection::target_cell_kind(cell_kind kind) { + return network_selection(std::make_shared(kind)); +} + +network_selection network_selection::source_label(std::vector labels) { + return network_selection( + std::make_shared(std::move(labels))); +} + +network_selection network_selection::target_label(std::vector labels) { + return network_selection( + std::make_shared(std::move(labels))); +} + +network_selection network_selection::source_cell(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); +} + +network_selection network_selection::source_cell(gid_range range) { + return network_selection(std::make_shared(range)); +} + +network_selection network_selection::target_cell(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); +} + +network_selection network_selection::target_cell(gid_range range) { + return network_selection(std::make_shared(range)); +} + +network_selection network_selection::chain(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); +} + +network_selection network_selection::chain(gid_range range) { + return network_selection(std::make_shared(range)); +} + +network_selection network_selection::chain_reverse(gid_range range) { + return network_selection(std::make_shared(range)); +} + +network_selection network_selection::complement(network_selection s) { + return network_selection( + std::make_shared(std::move(s.impl_))); +} + +network_selection network_selection::inter_cell() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::random(unsigned seed, network_value p) { + return network_selection(std::make_shared(seed, std::move(p))); +} + +network_selection network_selection::custom(custom_func_type func) { + return network_selection(std::make_shared(std::move(func))); +} + +network_selection network_selection::distance_lt(double d) { + return network_selection(std::make_shared(d)); +} + +network_selection network_selection::distance_gt(double d) { + return network_selection(std::make_shared(d)); +} + +network_value::network_value(std::shared_ptr impl): impl_(std::move(impl)) {} + +network_value network_value::scalar(double value) { + return network_value(std::make_shared(value)); +} + +network_value network_value::distance(double scale) { + return network_value(std::make_shared(scale)); +} + +network_value network_value::uniform_distribution(unsigned seed, + const std::array& range) { + return network_value(std::make_shared(seed, range)); +} + +network_value network_value::normal_distribution(unsigned seed, double mean, double std_deviation) { + return network_value( + std::make_shared(seed, mean, std_deviation)); +} + +network_value network_value::truncated_normal_distribution(unsigned seed, + double mean, + double std_deviation, + const std::array& range) { + return network_value(std::make_shared( + seed, mean, std_deviation, range)); +} + +network_value network_value::custom(custom_func_type func) { + return network_value(std::make_shared(std::move(func))); +} + +network_value network_value::named(std::string name) { + return network_value(std::make_shared(std::move(name))); +} + +network_label_dict& network_label_dict::set(const std::string& name, network_selection s) { + selections_.insert_or_assign(name, std::move(s)); + return *this; +} + +network_label_dict& network_label_dict::set(const std::string& name, network_value v) { + values_.insert_or_assign(name, std::move(v)); + return *this; +} + +network_value network_value::add(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::sub(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::mul(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::div(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::exp(network_value v) { + return network_value(std::make_shared(std::move(v.impl_))); +} + +network_value network_value::log(network_value v) { + return network_value(std::make_shared(std::move(v.impl_))); +} + +network_value network_value::min(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::max(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::if_else(network_selection cond, + network_value true_value, + network_value false_value) { + return network_value(std::make_shared( + std::move(cond.impl_), std::move(true_value.impl_), std::move(false_value.impl_))); +} + +std::optional network_label_dict::selection(const std::string& name) const { + auto it = selections_.find(name); + if (it != selections_.end()) return it->second; + + return std::nullopt; +} + +std::optional network_label_dict::value(const std::string& name) const { + auto it = values_.find(name); + if (it != values_.end()) return it->second; + + return std::nullopt; +} + +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_selection& s) { + if (s.impl_) s.impl_->print(os); + return os; +} + +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_value& v) { + if (v.impl_) v.impl_->print(os); + return os; +} + +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_site_info& s) { + + os << ""; + return os; +} + +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_connection_info& s) { + + os << ""; + return os; +} + +ARB_ARBOR_API network_selection join(network_selection left, network_selection right) { + return network_selection::join(std::move(left), std::move(right)); +} + +ARB_ARBOR_API network_selection intersect(network_selection left, network_selection right) { + return network_selection::intersect(std::move(left), std::move(right)); +} + +} // namespace arb diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp new file mode 100644 index 0000000000..11d138f089 --- /dev/null +++ b/arbor/network_impl.cpp @@ -0,0 +1,308 @@ +#include "network_impl.hpp" +#include "cell_group_factory.hpp" +#include "communication/distributed_for_each.hpp" +#include "label_resolution.hpp" +#include "network_impl.hpp" +#include "threading/threading.hpp" +#include "util/rangeutil.hpp" +#include "util/spatial_tree.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace arb { + +namespace { +struct network_site_info_extended { + network_site_info_extended(network_site_info info, cell_lid_type lid): + info(std::move(info)), + lid(lid) {} + + network_site_info info; + cell_lid_type lid; +}; + +void push_back(const domain_decomposition& dom_dec, + std::vector& vec, + const network_site_info_extended& source, + const network_site_info_extended& target, + double weight, + double delay) { + vec.emplace_back(connection{{source.info.gid, source.lid}, + target.lid, + (float)weight, + (float)delay, + dom_dec.index_on_domain(target.info.gid)}); +} + +void push_back(const domain_decomposition& dom_dec, + std::vector& vec, + const network_site_info_extended& source, + const network_site_info_extended& target, + double weight, + double delay) { + vec.emplace_back(source.info, target.info, weight, delay); +} + +template +std::vector generate_network_connections_impl(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec) { + const auto description_opt = rec.network_description(); + if (!description_opt.has_value()) return {}; + + const distributed_context& distributed = *(ctx->distributed); + + const auto& description = description_opt.value(); + + const auto selection_ptr = thingify(description.selection, description.dict); + const auto weight_ptr = thingify(description.weight, description.dict); + const auto delay_ptr = thingify(description.delay, description.dict); + + const auto& selection = *selection_ptr; + const auto& weight = *weight_ptr; + const auto& delay = *delay_ptr; + + std::unordered_map> gids_by_kind; + + for (const auto& group: dom_dec.groups()) { + auto& gids = gids_by_kind[group.kind]; + for (const auto& gid: group.gids) { gids.emplace_back(gid); } + } + + const auto num_batches = ctx->thread_pool->get_num_threads(); + std::vector> src_site_batches(num_batches); + std::vector> tgt_site_batches(num_batches); + + for (const auto& [kind, gids]: gids_by_kind) { + const auto batch_size = (gids.size() + num_batches - 1) / num_batches; + // populate network sites for source and target + if (kind == cell_kind::cable) { + const auto& cable_gids = gids; + threading::parallel_for::apply( + 0, cable_gids.size(), batch_size, ctx->thread_pool.get(), [&](int i) { + const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); + auto& src_sites = src_site_batches[batch_idx]; + auto& tgt_sites = tgt_site_batches[batch_idx]; + const auto gid = cable_gids[i]; + const auto kind = rec.get_cell_kind(gid); + // We need access to morphology, so the cell is create directly + cable_cell cell; + try { + cell = util::any_cast(rec.get_cell_description(gid)); + } + catch (std::bad_any_cast&) { + throw bad_cell_description(kind, gid); + } + + auto lid_to_label = [](const std::unordered_multimap& map, + cell_lid_type lid) -> hash_type { + for (const auto& [label, range]: map) { + if (lid >= range.begin && lid < range.end) return label; + } + throw arbor_internal_error("unkown lid"); + }; + + place_pwlin location_resolver(cell.morphology(), rec.get_cell_isometry(gid)); + + // check all synapses of cell for potential target + + for (const auto& [_, placed_synapses]: cell.synapses()) { + for (const auto& p_syn: placed_synapses) { + const auto& label = lid_to_label(cell.synapse_ranges(), p_syn.lid); + + if (selection.select_target(cell_kind::cable, gid, label)) { + const mpoint point = location_resolver.at(p_syn.loc); + tgt_sites.emplace_back( + network_site_info{ + gid, cell_kind::cable, label, p_syn.loc, point}, + p_syn.lid); + } + } + } + + // check all detectors of cell for potential source + for (const auto& p_det: cell.detectors()) { + const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); + if (selection.select_source(cell_kind::cable, gid, label)) { + const mpoint point = location_resolver.at(p_det.loc); + src_sites.emplace_back( + network_site_info{gid, cell_kind::cable, label, p_det.loc, point}, + p_det.lid); + } + } + }); + } + else { + // Assuming all other cell types do not have a morphology. We can use label + // resolution through factory and set local position to 0. + auto factory = cell_kind_implementation(kind, backend_kind::multicore, *ctx, 0); + + // We only need the label ranges + cell_label_range sources, targets; + std::ignore = factory(gids, rec, sources, targets); + + auto& src_sites = src_site_batches[0]; + auto& tgt_sites = tgt_site_batches[0]; + + std::size_t source_label_offset = 0; + std::size_t target_label_offset = 0; + for (std::size_t i = 0; i < gids.size(); ++i) { + const auto gid = gids[i]; + const auto iso = rec.get_cell_isometry(gid); + const auto point = iso.apply(mpoint{0.0, 0.0, 0.0, 0.0}); + const auto num_source_labels = sources.sizes.at(i); + const auto num_target_labels = targets.sizes.at(i); + + // Iterate over each source label for current gid + for (std::size_t j = source_label_offset; + j < source_label_offset + num_source_labels; + ++j) { + const auto& label = sources.labels.at(j); + const auto& range = sources.ranges.at(j); + for (auto lid = range.begin; lid < range.end; ++lid) { + if (selection.select_source(kind, gid, label)) { + src_sites.emplace_back( + network_site_info{gid, kind, label, mlocation{0, 0.0}, point}, lid); + } + } + } + + // Iterate over each target label for current gid + for (std::size_t j = target_label_offset; + j < target_label_offset + num_target_labels; + ++j) { + const auto& label = targets.labels.at(j); + const auto& range = targets.ranges.at(j); + for (auto lid = range.begin; lid < range.end; ++lid) { + if (selection.select_target(kind, gid, label)) { + tgt_sites.emplace_back( + network_site_info{gid, kind, label, mlocation{0, 0.0}, point}, lid); + } + } + } + + source_label_offset += num_source_labels; + target_label_offset += num_target_labels; + } + } + } + + auto src_sites = std::move(src_site_batches.back()); + src_site_batches.pop_back(); + for (const auto& batch: src_site_batches) + src_sites.insert(src_sites.end(), batch.begin(), batch.end()); + + auto tgt_sites = std::move(tgt_site_batches.back()); + tgt_site_batches.pop_back(); + for (const auto& batch: tgt_site_batches) + tgt_sites.insert(tgt_sites.end(), batch.begin(), batch.end()); + + // create octree + const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; + const std::size_t max_leaf_size = 100; + spatial_tree local_tgt_tree(max_depth, + max_leaf_size, + std::move(tgt_sites), + [](const network_site_info_extended& ex) + -> spatial_tree::point_type { + return { + ex.info.global_location.x, ex.info.global_location.y, ex.info.global_location.z}; + }); + + // select connections + std::vector> connection_batches(num_batches); + + auto sample_sources = [&](const util::range& source_range) { + const auto batch_size = (source_range.size() + num_batches - 1) / num_batches; + threading::parallel_for::apply( + 0, source_range.size(), batch_size, ctx->thread_pool.get(), [&](int i) { + const auto& source = source_range[i]; + const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); + auto& connections = connection_batches[batch_idx]; + + auto sample = [&](const network_site_info_extended& target) { + if (selection.select_connection(source.info, target.info)) { + const auto w = weight.get(source.info, target.info); + const auto d = delay.get(source.info, target.info); + + push_back(dom_dec, connections, source, target, w, d); + } + }; + + if (selection.max_distance().has_value()) { + const double d = selection.max_distance().value(); + local_tgt_tree.bounding_box_for_each( + decltype(local_tgt_tree)::point_type{source.info.global_location.x - d, + source.info.global_location.y - d, + source.info.global_location.z - d}, + decltype(local_tgt_tree)::point_type{source.info.global_location.x + d, + source.info.global_location.y + d, + source.info.global_location.z + d}, + sample); + } + else { local_tgt_tree.for_each(sample); } + }); + }; + + distributed_for_each(sample_sources, distributed, util::range_view(src_sites)); + + // concatenate + auto connections = std::move(connection_batches.front()); + for (std::size_t i = 1; i < connection_batches.size(); ++i) { + connections.insert( + connections.end(), connection_batches[i].begin(), connection_batches[i].end()); + } + return connections; +} + +} // namespace + +std::vector generate_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec) { + return generate_network_connections_impl(rec, ctx, dom_dec); +} + +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec) { + auto connections = generate_network_connections_impl(rec, ctx, dom_dec); + + // generated connections may have different order each time due to multi-threading. + // Sort before returning to user for reproducibility. + std::sort(connections.begin(), connections.end()); + + return connections; +} + +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec) { + auto ctx = arb::make_context(); + auto decomp = arb::partition_load_balance(rec, ctx); + return generate_network_connections(rec, ctx, decomp); +} + +} // namespace arb diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp new file mode 100644 index 0000000000..65c62767ef --- /dev/null +++ b/arbor/network_impl.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "connection.hpp" +#include "distributed_context.hpp" +#include "label_resolution.hpp" + +namespace arb { + +struct network_selection_impl { + virtual std::optional max_distance() const { return std::nullopt; } + + virtual bool select_connection(const network_site_info& source, + const network_site_info& target) const = 0; + + virtual bool select_source(cell_kind kind, cell_gid_type gid, hash_type tag) const = 0; + + virtual bool select_target(cell_kind kind, cell_gid_type gid, hash_type tag) const = 0; + + virtual void initialize(const network_label_dict& dict){}; + + virtual void print(std::ostream& os) const = 0; + + virtual ~network_selection_impl() = default; +}; + +inline std::shared_ptr thingify(network_selection s, + const network_label_dict& dict) { + s.impl_->initialize(dict); + return s.impl_; +} + +struct network_value_impl { + virtual double get(const network_site_info& source, const network_site_info& target) const = 0; + + virtual void initialize(const network_label_dict& dict){}; + + virtual void print(std::ostream& os) const = 0; + + virtual ~network_value_impl() = default; +}; + +inline std::shared_ptr thingify(network_value v, + const network_label_dict& dict) { + v.impl_->initialize(dict); + return v.impl_; +} + +std::vector generate_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec); + +} // namespace arb diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp index 55c027710a..998963790b 100644 --- a/arbor/simulation.cpp +++ b/arbor/simulation.cpp @@ -1,13 +1,14 @@ #include #include -#include #include #include #include +#include #include #include #include +#include #include #include "epoch.hpp" @@ -86,7 +87,7 @@ class simulation_state { public: simulation_state(const recipe& rec, const domain_decomposition& decomp, context ctx, arb_seed_type seed); - void update(const connectivity& rec); + void update(const recipe& rec); void reset(); @@ -278,13 +279,13 @@ simulation_state::simulation_state( PL(); PE(init:simulation:comm); - communicator_ = communicator(rec, ddc_, *ctx_); + communicator_ = communicator(rec, ddc_, ctx_); PL(); update(rec); epoch_.reset(); } -void simulation_state::update(const connectivity& rec) { +void simulation_state::update(const recipe& rec) { communicator_.update_connections(rec, ddc_, source_resolution_map_, target_resolution_map_); // Use half minimum delay of the network for max integration interval. t_interval_ = min_delay()/2; @@ -575,7 +576,7 @@ void simulation::reset() { impl_->reset(); } -void simulation::update(const connectivity& rec) { impl_->update(rec); } +void simulation::update(const recipe& rec) { impl_->update(rec); } time_type simulation::run(const units::quantity& tfinal, const units::quantity& dt) { auto dt_ms = dt.value_as(units::ms); diff --git a/arbor/threading/threading.cpp b/arbor/threading/threading.cpp index 9b7c639ca8..44166d3396 100644 --- a/arbor/threading/threading.cpp +++ b/arbor/threading/threading.cpp @@ -1,8 +1,10 @@ #include +#include #include #include #include +#include #include "threading/threading.hpp" #include "affinity.hpp" @@ -184,3 +186,9 @@ void task_system::async(priority_task ptsk) { std::unordered_map task_system::get_thread_ids() const { return thread_ids_; }; + +std::optional task_system::get_current_thread_id() const { + const auto it = thread_ids_.find(std::this_thread::get_id()); + if(it != thread_ids_.end()) return it->second; + return std::nullopt; +} diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp index c690781e89..7e2e4a2509 100644 --- a/arbor/threading/threading.hpp +++ b/arbor/threading/threading.hpp @@ -8,11 +8,12 @@ #include #include #include +#include #include #include -#include #include #include +#include #include @@ -220,6 +221,9 @@ class ARB_ARBOR_API task_system { // Returns the thread_id map std::unordered_map get_thread_ids() const; + + // Returns the calling thread id if part of the task system + std::optional get_current_thread_id() const; }; class task_group { diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp new file mode 100644 index 0000000000..f572781f87 --- /dev/null +++ b/arbor/util/spatial_tree.hpp @@ -0,0 +1,190 @@ +#pragma once + +#include "util/visit_variant.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace arb { + +// An immutable spatial data structure for storing and iterating over data in "DIM" dimensional +// space. If DIM = 1 it's a binary tree, if DIM = 2 it's a quad tree and so on. +template +class spatial_tree { +public: + static_assert(DIM >= 1, "Dimension of tree must be at least 1."); + + using value_type = T; + using point_type = std::array; + using node_data = std::vector; + using leaf_data = std::vector; + using location_func_type = point_type (*)(const T &); + + spatial_tree(): size_(0), data_(leaf_data()) {} + + // Create a tree of given maximum depth and target leaf size. If any leaf holds more than the + // target size, it is recursively split into up to 2^DIM nodes until reaching the maximum depth. + // The "location" function type must have signature (const T&) -> point_type. + spatial_tree(std::size_t max_depth, + std::size_t leaf_size_target, + leaf_data data, + location_func_type location): + size_(data.size()), + data_(std::move(data)), + location_(location) { + auto &leaf_d = std::get(data_); + if (leaf_d.empty()) return; + + min_.fill(std::numeric_limits::max()); + max_.fill(std::numeric_limits::lowest()); + + for (const auto &d: leaf_d) { + const auto p = location(d); + for (std::size_t i = 0; i < DIM; ++i) { + if (p[i] < min_[i]) min_[i] = p[i]; + if (p[i] > max_[i]) max_[i] = p[i]; + } + } + + point_type mid; + for (std::size_t i = 0; i < DIM; ++i) { mid[i] = (max_[i] - min_[i]) / 2.0 + min_[i]; } + + if (max_depth > 1 && leaf_d.size() > leaf_size_target) { + constexpr auto divisor = math::pow(2, DIM); + + // The initial index of the sub node containing p + auto sub_node_index = [&](const point_type &p) { + std::size_t index = 0; + for (std::size_t i = 0; i < DIM; ++i) { index += i * 2 * (p[i] >= mid[i]); } + return index; + }; + + node_data new_nodes; + new_nodes.reserve(divisor); + + // assign each point to sub-node + std::array new_leaf_data; + for (const auto &d: leaf_d) { + const auto p = location(d); + new_leaf_data[sub_node_index(p)].emplace_back(d); + } + + // move data into new sub-nodes if not empty + for (auto &l_d: new_leaf_data) { + if (l_d.size()) + new_nodes.emplace_back(max_depth - 1, leaf_size_target, std::move(l_d), location); + } + + // replace current data_ with new sub-nodes + this->data_ = std::move(new_nodes); + } + } + + spatial_tree(const spatial_tree &) = default; + + spatial_tree(spatial_tree &&t) noexcept(std::is_nothrow_move_assignable_v) { + *this = std::move(t); + } + + spatial_tree &operator=(const spatial_tree &) = default; + + spatial_tree &operator=(spatial_tree &&t) noexcept( + std::is_nothrow_default_constructible_v && + std::is_nothrow_move_assignable_v && + std::is_nothrow_default_constructible_v && + std::is_nothrow_move_assignable_vdata_)>) { + + data_ = std::move(t.data_); + size_ = t.size_; + min_ = t.min_; + max_ = t.max_; + location_ = t.location_; + + t.data_ = leaf_data(); + t.size_ = 0; + t.min_ = point_type(); + t.max_ = point_type(); + t.location_ = nullptr; + + return *this; + } + + // Iterate over all points recursively. + // func must have signature `void func(const T&)`. + template + inline void for_each(const F &func) const { + util::visit_variant( + data_, + [&](const node_data &data) { + for (const auto &node: data) { node.for_each(func); } + }, + [&](const leaf_data &data) { + for (const auto &d: data) { func(d); } + }); + } + + // Iterate over all points within the given bounding box recursively. + // func must have signature `void func(const T&)`. + template + inline void bounding_box_for_each(const point_type &box_min, + const point_type &box_max, + const F &func) const { + auto all_smaller_eq = [](const point_type &lhs, const point_type &rhs) { + bool result = true; + for (std::size_t i = 0; i < DIM; ++i) { result &= lhs[i] <= rhs[i]; } + return result; + }; + + util::visit_variant( + data_, + [&](const node_data &data) { + if (all_smaller_eq(box_min, min_) && all_smaller_eq(max_, box_max)) { + // sub-nodes fully inside box -> call without further boundary + // checks + for (const auto &node: data) { node.template for_each(func); } + } + else { + // sub-nodes partially overlap bounding box + for (const auto &node: data) { + if (all_smaller_eq(node.min_, box_max) && + all_smaller_eq(box_min, node.max_)) + node.template bounding_box_for_each(box_min, box_max, func); + } + } + }, + [&](const leaf_data &data) { + if (all_smaller_eq(box_min, min_) && all_smaller_eq(max_, box_max)) { + // sub-nodes fully inside box -> call without further boundary + // checks + for (const auto &d: data) { func(d); } + } + else { + // sub-nodes partially overlap bounding box + for (const auto &d: data) { + const auto p = location_(d); + if (all_smaller_eq(p, box_max) && all_smaller_eq(box_min, p)) { func(d); } + } + } + }); + + } + + inline std::size_t size() const noexcept { return size_; } + + inline bool empty() const noexcept { return !size_; } + +private: + std::size_t size_; + point_type min_, max_; + std::variant data_; + location_func_type location_; +}; + +} // namespace arb diff --git a/arbor/util/visit_variant.hpp b/arbor/util/visit_variant.hpp new file mode 100644 index 0000000000..bf6e32ef16 --- /dev/null +++ b/arbor/util/visit_variant.hpp @@ -0,0 +1,41 @@ +#pragma once + + +#include +#include + + +namespace arb { +namespace util { + +namespace impl { +template +inline void visit_variant_impl(VARIANT &&v, F &&f) { + constexpr auto index = std::variant_size_v> - 1; + if (v.index() == index) f(std::get(v)); +} + +template +inline void visit_variant_impl(VARIANT &&v, F &&f, FUNCS &&...functions) { + constexpr auto index = + std::variant_size_v> - sizeof...(FUNCS) - 1; + if (v.index() == index) f(std::get(v)); + visit_variant_impl(std::forward(v), std::forward(functions)...); +} +} // namespace impl + +/* + * Similar to std::visit, call contained type with matching function. Expects a function for each + * type in variant and in the same order. More performant than std::visit through the use of + * indexing instead of function tables. + */ +template +inline void visit_variant(VARIANT &&v, FUNCS &&...functions) { + static_assert(std::variant_size_v> == + sizeof...(FUNCS), + "The first argument must be of type std::variant and the " + "number of functions must match the variant size."); + impl::visit_variant_impl(std::forward(v), std::forward(functions)...); +} +} // namespace util +} // namespace arb diff --git a/arborio/CMakeLists.txt b/arborio/CMakeLists.txt index bd4bbce4a4..5a4bae4254 100644 --- a/arborio/CMakeLists.txt +++ b/arborio/CMakeLists.txt @@ -6,6 +6,7 @@ set(arborio-sources cv_policy_parse.cpp label_parse.cpp neuroml.cpp + networkio.cpp nml_parse_morphology.cpp) add_library(arborio ${arborio-sources}) diff --git a/arborio/include/arborio/networkio.hpp b/arborio/include/arborio/networkio.hpp new file mode 100644 index 0000000000..c57f0d98a5 --- /dev/null +++ b/arborio/include/arborio/networkio.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace arborio { + +struct ARB_SYMBOL_VISIBLE network_parse_error: arb::arbor_exception { + explicit network_parse_error(const std::string& msg, const arb::src_location& loc); + explicit network_parse_error(const std::string& msg): arb::arbor_exception(msg) {} +}; + +template +using parse_network_hopefully = arb::util::expected; + +ARB_ARBORIO_API parse_network_hopefully parse_network_selection_expression( + const std::string& s); +ARB_ARBORIO_API parse_network_hopefully parse_network_value_expression( + const std::string& s); + +namespace literals { +inline arb::network_selection operator"" _ns(const char* s, std::size_t) { + if (auto r = parse_network_selection_expression(s)) + return *r; + else + throw r.error(); +} + +inline arb::network_value operator"" _nv(const char* s, std::size_t) { + if (auto r = parse_network_value_expression(s)) + return *r; + else + throw r.error(); +} + +} // namespace literals +} // namespace arborio diff --git a/arborio/networkio.cpp b/arborio/networkio.cpp new file mode 100644 index 0000000000..194ff41c9e --- /dev/null +++ b/arborio/networkio.cpp @@ -0,0 +1,369 @@ +#include +#include +#include + +#include + +#include +#include +#include + +#include "parse_helpers.hpp" + +namespace arborio { + +network_parse_error::network_parse_error(const std::string& msg, const arb::src_location& loc): + arb::arbor_exception( + concat("error in label description: ", msg, " at :", loc.line, ":", loc.column)) {} + +namespace { +using eval_map_type = std::unordered_multimap; + +eval_map_type network_eval_map{ + {"gid-range", + make_call([](int begin, int end) { return arb::gid_range(begin, end); }, + "Gid range [begin, end) with step size 1: ((begin:integer) (end:integer))")}, + {"gid-range", + make_call( + [](int begin, int end, int step) { return arb::gid_range(begin, end, step); }, + "Gid range [begin, end) with step size: ((begin:integer) (end:integer) " + "(step:integer))")}, + + // cell kind + {"cable-cell", make_call<>([]() { return arb::cell_kind::cable; }, "Cable cell kind")}, + {"lif-cell", make_call<>([]() { return arb::cell_kind::lif; }, "Lif cell kind")}, + {"benchmark-cell", + make_call<>([]() { return arb::cell_kind::benchmark; }, "Benchmark cell kind")}, + {"spike-source-cell", + make_call<>([]() { return arb::cell_kind::spike_source; }, "Spike source cell kind")}, + + // network_selection + {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, + {"none", make_call<>(arb::network_selection::none, "network selection of no cells and labels")}, + {"inter-cell", + make_call<>(arb::network_selection::inter_cell, + "network selection of inter-cell connections only")}, + {"network-selection", + make_call(arb::network_selection::named, + "network selection with 1 argument: (value:string)")}, + {"intersect", + make_conversion_fold(arb::network_selection::intersect, + "intersection of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"join", + make_conversion_fold(arb::network_selection::join, + "join or union operation of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"symmetric-difference", + make_conversion_fold(arb::network_selection::symmetric_difference, + "symmetric difference operation between network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"difference", + make_call( + arb::network_selection::difference, + "difference of first selection with the second one: " + "(network_selection network_selection)")}, + {"complement", + make_call(arb::network_selection::complement, + "complement of given selection argument: (network_selection)")}, + {"source-cell-kind", + make_call(arb::network_selection::source_cell_kind, + "all sources of cells matching given cell kind argument: (kind:cell-kind)")}, + {"target-cell-kind", + make_call(arb::network_selection::target_cell_kind, + "all targets of cells matching given cell kind argument: (kind:cell-kind)")}, + {"source-label", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector labels; + std::transform( + vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::source_label(std::move(labels)); + }, + "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"target-label", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector labels; + std::transform( + vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::target_label(std::move(labels)); + }, + "all targets in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"source-cell", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::source_cell(std::move(gids)); + }, + "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"source-cell", + make_call(static_cast( + arb::network_selection::source_cell), + "all sources in cell with gid range: (range:gid-range)")}, + {"target-cell", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::target_cell(std::move(gids)); + }, + "all targets in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"target-cell", + make_call(static_cast( + arb::network_selection::target_cell), + "all targets in cell with gid range: " + "(range:gid-range)")}, + {"chain", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::chain(std::move(gids)); + }, + "A chain of connections in the given order of gids in list, such that entry \"i\" is " + "the source and entry \"i+1\" the target: (gid:integer) [...(gid:integer)]")}, + {"chain", + make_call( + static_cast(arb::network_selection::chain), + "A chain of connections for all gids in range [begin, end) with given step size. Each " + "entry \"i\" is connected as source to the target \"i+1\": (begin:integer) " + "(end:integer) (step:integer)")}, + {"chain-reverse", + make_call(arb::network_selection::chain_reverse, + "A chain of connections for all gids in range [begin, end) with given step size. Each " + "entry \"i+1\" is connected as source to the target \"i\". This results in " + "connection directions in reverse compared to the (chain-range ...) selection: " + "(begin:integer) " + "(end:integer) (step:integer)")}, + {"random", + make_call(arb::network_selection::random, + "randomly selected with given seed and probability. 2 arguments: (seed:integer, " + "p:real)")}, + {"random", + make_call(arb::network_selection::random, + "randomly selected with given seed and probability function. Any probability value is " + "clamped to [0.0, 1.0]. 2 arguments: (seed:integer, " + "p:network-value)")}, + {"distance-lt", + make_call(arb::network_selection::distance_lt, + "Select if distance between source and target is less than given distance in " + "micro meter: (distance:real)")}, + {"distance-gt", + make_call(arb::network_selection::distance_gt, + "Select if distance between source and target is greater than given distance in " + "micro meter: (distance:real)")}, + + // network_value + {"scalar", + make_call(arb::network_value::scalar, + "A fixed scalar value. 1 argument: (value:real)")}, + {"network-value", + make_call(arb::network_value::named, + "A named network value with 1 argument: (value:string)")}, + {"distance", + make_call(arb::network_value::distance, + "Distance between source and target scaled by given value with unit [1/um]. 1 " + "argument: (scale:real)")}, + {"distance", + make_call<>([]() { return arb::network_value::distance(1.0); }, + "Distance between source and target scaled by 1.0 with unit [1/um].")}, + {"uniform-distribution", + make_call( + [](unsigned seed, double begin, double end) { + return arb::network_value::uniform_distribution(seed, {begin, end}); + }, + "Uniform random distribution within interval [begin, end): (seed:integer, begin:real, " + "end:real)")}, + {"normal-distribution", + make_call(arb::network_value::normal_distribution, + "Normal random distribution with given mean and standard deviation: (seed:integer, " + "mean:real, std_deviation:real)")}, + {"truncated-normal-distribution", + make_call( + [](unsigned seed, double mean, double std_deviation, double begin, double end) { + return arb::network_value::truncated_normal_distribution( + seed, mean, std_deviation, {begin, end}); + }, + "Truncated normal random distribution with given mean and standard deviation within " + "interval [begin, end]: (seed:integer, mean:real, std_deviation:real, begin:real, " + "end:real)")}, + {"if-else", + make_call(arb::network_value::if_else, + "Returns the first network-value if a connection is the given network-selection and " + "the second network-value otherwise. 3 arguments: (sel:network-selection, " + "true_value:network-value, false_value:network_value)")}, + {"add", + make_conversion_fold( + arb::network_value::add, + "Sum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"sub", + make_conversion_fold( + arb::network_value::sub, + "Subtraction of network values from the first argument with at least 2 arguments: " + "((network-value | double) (network-value | double) [...(network-value | double)])")}, + {"mul", + make_conversion_fold( + arb::network_value::mul, + "Multiplication of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"div", + make_conversion_fold( + arb::network_value::div, + "Division of the first argument by each following network value sequentially with at " + "least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"min", + make_conversion_fold( + arb::network_value::min, + "Minimum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"max", + make_conversion_fold( + arb::network_value::max, + "Minimum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"log", make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, + {"log", + make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, + {"exp", + make_call(arb::network_value::exp, + "Exponential function. 1 argument: (value:real)")}, + {"exp", + make_call(arb::network_value::exp, + "Exponential function. 1 argument: (value:real)")}, +}; + +parse_network_hopefully eval(const s_expr& e, const eval_map_type& map); + +parse_network_hopefully> eval_args(const s_expr& e, + const eval_map_type& map) { + if (!e) return {std::vector{}}; // empty argument list + std::vector args; + for (auto& h: e) { + if (auto arg = eval(h, map)) { args.push_back(std::move(*arg)); } + else { return util::unexpected(std::move(arg.error())); } + } + return args; +} + +// Generate a string description of a function evaluation of the form: +// Example output: +// 'foo' with 1 argument: (real) +// 'bar' with 0 arguments +// 'cat' with 3 arguments: (locset region integer) +// Where 'foo', 'bar' and 'cat' are the name of the function, and the +// types (integer, real, region, locset) are inferred from the arguments. +std::string eval_description(const char* name, const std::vector& args) { + auto type_string = [](const std::type_info& t) -> const char* { + if (t == typeid(int)) return "integer"; + if (t == typeid(double)) return "real"; + if (t == typeid(arb::region)) return "region"; + if (t == typeid(arb::locset)) return "locset"; + if (t == typeid(arb::network_selection)) return "network_selection"; + if (t == typeid(arb::network_value)) return "network_value"; + return "unknown"; + }; + + const auto nargs = args.size(); + std::string msg = concat("'", name, "' with ", nargs, " argument", nargs != 1u ? "s:" : ":"); + if (nargs) { + msg += " ("; + bool first = true; + for (auto& a: args) { + msg += concat(first ? "" : " ", type_string(a.type())); + first = false; + } + msg += ")"; + } + return msg; +} + +// Evaluate an s expression. +// On success the result is wrapped in std::any, where the result is one of: +// int : an integer atom +// double : a real atom +// std::string : a string atom: to be treated as a label +// arb::region : a region +// arb::locset : a locset +// +// If there invalid input is detected, hopefully return value contains +// a label_error_state with an error string and location. +// +// If there was an unexpected/fatal error, an exception will be thrown. +parse_network_hopefully eval(const s_expr& e, const eval_map_type& map) { + if (e.is_atom()) { return eval_atom(e); } + if (e.head().is_atom()) { + // This must be a function evaluation, where head is the function name, and + // tail is a list of arguments. + + // Evaluate the arguments, and return error state if an error occurred. + auto args = eval_args(e.tail(), map); + if (!args) { return util::unexpected(args.error()); } + + // Find all candidate functions that match the name of the function. + auto& name = e.head().atom().spelling; + auto matches = map.equal_range(name); + // Search for a candidate that matches the argument list. + for (auto i = matches.first; i != matches.second; ++i) { + if (i->second.match_args(*args)) { // found a match: evaluate and return. + return i->second.eval(*args); + } + } + + // Unable to find a match: try to return a helpful error message. + const auto nc = std::distance(matches.first, matches.second); + auto msg = concat("No matches for ", + eval_description(name.c_str(), *args), + "\n There are ", + nc, + " potential candidates", + nc ? ":" : "."); + int count = 0; + for (auto i = matches.first; i != matches.second; ++i) { + msg += concat("\n Candidate ", ++count, " ", i->second.message); + } + return util::unexpected(network_parse_error(msg, location(e))); + } + + return util::unexpected(network_parse_error( + concat("'", e, "' is not either integer, real expression of the form (op )"), + location(e))); +} + +} // namespace + +parse_network_hopefully parse_network_selection_expression( + const std::string& s) { + if (auto e = eval(parse_s_expr(s), network_eval_map)) { + if (e->type() == typeid(arb::network_selection)) { + return {std::move(std::any_cast(*e))}; + } + return util::unexpected(network_parse_error(concat("Invalid iexpr description: '", s))); + } + else { return util::unexpected(network_parse_error(std::string() + e.error().what())); } +} + +parse_network_hopefully parse_network_value_expression(const std::string& s) { + if (auto e = eval(parse_s_expr(s), network_eval_map)) { + if (e->type() == typeid(arb::network_value)) { + return {std::move(std::any_cast(*e))}; + } + return util::unexpected(network_parse_error(concat("Invalid iexpr description: '", s))); + } + else { return util::unexpected(network_parse_error(std::string() + e.error().what())); } +} + +} // namespace arborio diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index 8b7de20c46..bdb93f3a29 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -18,6 +18,272 @@ These sites as such are not connected yet, however the :ref:`recipe `, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or target label, cell indices and also distance between source and target. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. +Each connection also requires a weight and delay value. For this purpose, a ``network_value`` type is available, that allows to mathematically describe the value calculation using common math functions, as well random distributions. + +The following example shows the relevant recipe functions, where cells are connected into a ring with additional random connections between them: + +.. code-block:: python + + def network_description(self): + seed = 42 + + # create a chain + chain = f"(chain (gid-range 0 {self.ncells}))" + # connect front and back of chain to form ring + ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (target-cell 0)))" + + # Create random connections with probability inversely proportional to the distance within a + # radius + max_dist = 400.0 # μm + probability = f"(div (sub {max_dist} (distance)) {max_dist})" + rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" + + # combine ring with random selection + s = f"(join {ring} {rand})" + # restrict to inter-cell connections and certain source / target labels + s = f'(intersect {s} (inter-cell) (source-label "detector") (target-label "syn"))' + + # fixed weight for connections in ring + w_ring = "(scalar 0.01)" + # random normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS + # and truncated to [0.005, 0.035] + w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" + + # combine into single weight expression + w = f"(if-else {ring} {w_ring} {w_rand})" + + # fixed delay + d = "(scalar 5.0)" # ms delay + + return arbor.network_description(s, w, d, {}) + + def cell_isometry(self, gid): + # place cells with equal distance on a circle + radius = 500.0 # μm + angle = 2.0 * math.pi * gid / self.ncells + return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0) + + +The export function ``generate_network_connections`` allows the inspection of generated connections. The exported connections include the cell index, local label and location of both source and target. + + +.. note:: + + Expressions using distance require a cell isometry to resolve the global location of connection points. + +.. note:: + + A high-level description may be used together with providing explicit connection lists for each cell, but it is up to the user to avoid multiple connections between the same source and target. + +.. warning:: + + Generating connections always involves additional work and may increase the time spent in the simulation initialization phase. + + +.. _interconnectivity-selection-expressions: + +Network Selection Expressions +----------------------------- + +.. label:: (gid-range begin:integer end:integer) + + A range expression, representing a range of indices in the half-open interval [begin, end). + +.. label:: (gid-range begin:integer end:integer step:integer) + + A range expression, representing a range of indices in the half-open interval [begin, end) with a given step size. Step size must be positive. + +.. label:: (cable-cell) + + Cell kind expression for cable cells. + +.. label:: (lif-cell) + + Cell kind expression for lif cells. + +.. label:: (benchmark-cell) + + Cell kind expression for benchmark cells. + +.. label:: (spike-source-cell) + + Cell kind expression for spike source cells. + +.. label:: (all) + + A selection of all possible connections. + +.. label:: (none) + + A selection representing the empty set of possible connections. + +.. label:: (inter-cell) + + A selection of all connections that connect two different cells. + +.. label:: (network-selection name:string) + + A named selection within the network dictionary. + +.. label:: (intersect network-selection network-selection [...network-selection]) + + The intersection of at least two selections. + +.. label:: (join network-selection network-selection [...network-selection]) + + The union of at least two selections. + +.. label:: (symmetric-difference network-selection network-selection [...network-selection]) + + The symmetric difference of at least two selections. + +.. label:: (difference network-selection network-selection) + + The difference of two selections. + +.. label:: (difference network-selection) + + The complement or opposite of the given selection. + +.. label:: (source-cell-kind kind:cell-kind) + + All connections, where the source cell is of the given type. + +.. label:: (target-cell-kind kind:cell-kind) + + All connections, where the target cell is of the given type. + +.. label:: (source-label label:string) + + All connections, where the source label matches the given label. + +.. label:: (target-label label:string) + + All connections, where the target label matches the given label. + +.. label:: (source-cell integer [...integer]) + + All connections, where the source cell index matches one of the given integer values. + +.. label:: (source-cell range:gid-range) + + All connections, where the source cell index is contained in the given gid-range. + +.. label:: (target-cell integer [...integer]) + + All connections, where the target cell index matches one of the given integer values. + +.. label:: (target-cell range:gid-range) + + All connections, where the target cell index is contained in the given gid-range. + +.. label:: (chain integer [...integer]) + + A chain of connections between cells in the given order of in the list, such that entry "i" is the source and entry "i+1" the target. + +.. label:: (chain range:gid-range) + + A chain of connections between cells in the given order of the gid-range, such that entry "i" is the source and entry "i+1" the target. + +.. label:: (chain-reverse range:gid-range) + + A chain of connections between cells in reverse of the given order of the gid-range, such that entry "i+1" is the source and entry "i" the target. + +.. label:: (random seed:integer p:real) + + A random selection of connections, where each connection is selected with the given probability. + +.. label:: (random seed:integer p:network-value) + + A random selection of connections, where each connection is selected with the given probability expression. + +.. label:: (distance-lt dist:real) + + All connections, where the distance between source and target is less than the given value in micro meter. + +.. label:: (distance-gt dist:real) + + All connections, where the distance between source and target is greater than the given value in micro meter. + + +.. _interconnectivity-value-expressions: + +Network Value Expressions +------------------------- + +.. label:: (scalar value:real) + + A scalar of given value. + +.. label:: (network-value name:string) + + A named network value in the network dictionary. + +.. label:: (distance) + + The distance between source and target. + +.. label:: (distance value:real) + + The distance between source and target scaled by the given value. + +.. label:: (uniform-distribution seed:integer begin:real end:real) + + Uniform random distribution within the interval [begin, end). + +.. label:: (normal-distribution seed:integer mean:real std_deviation:real) + + Normal random distribution with given mean and standard deviation. + +.. label:: (truncated-normal-distribution seed:integer mean:real std_deviation:real begin:real end:real) + + Truncated normal random distribution with given mean and standard deviation within the interval [begin, end). + +.. label:: (if-else sel:network-selection true_value:network-value false_value:network-value) + + Truncated normal random distribution with given mean and standard deviation within the interval [begin, end). + +.. label:: (add (network-value | real) (network-value | real) [... (network-value | real)]) + + Addition of at least two network values or real numbers. + +.. label:: (sub (network-value | real) (network-value | real) [... (network-value | real)]) + + Subtraction of at least two network values or real numbers. + +.. label:: (mul (network-value | real) (network-value | real) [... (network-value | real)]) + + Multiplication of at least two network values or real numbers. + +.. label:: (div (network-value | real) (network-value | real) [... (network-value | real)]) + + Division of at least two network values or real numbers. + The expression is evaluated from the left to right, dividing the first element by each divisor in turn. + +.. label:: (min (network-value | real) (network-value | real) [... (network-value | real)]) + + Minimum of at least two network values or real numbers. + +.. label:: (max (network-value | real) (network-value | real) [... (network-value | real)]) + + Maximum of at least two network values or real numbers. + +.. label:: (log (network-value | real)) + + Logarithm of a network value or real number. + +.. label:: (exp (network-value | real)) + + Exponential function of a network value or real number. + + + .. _interconnectivity-mut: Mutability @@ -37,8 +303,8 @@ connection table outside calls to `run`, for example # extend the recipe to more connections rec.add_connections() - # use `connections_on` to build a new connection table - sim.update_connections(rec) + # use updated recipe to build a new connection table + sim.update(rec) # run simulation for 0.25ms with the extended connectivity sim.run(0.5, 0.025) @@ -48,12 +314,6 @@ must be explicitly included in the updated callback. This can also be used to update connection weights and delays. Note, however, that there is currently no way to introduce new sites to the simulation, nor any changes to gap junctions. -The ``update_connections`` method accepts either a full ``recipe`` (but will -**only** use the ``connections_on`` and ``events_generators`` callbacks) or a -``connectivity``, which is a reduced recipe exposing only the relevant callbacks. -Currently ``connectivity`` is only available in C++; Python users have to pass a -full recipe. - .. warning:: The semantics of connection updates are subtle and might produce surprising @@ -78,6 +338,8 @@ full recipe. in these callbacks. This is doubly important when using models with dynamic connectivity where the temptation to store all connections is even larger and each call to ``update`` will re-evaluate the corresponding callbacks. + Alternatively, connections can be generated by Arbor using the network DSL + through the ``network_description`` callback function. .. _interconnectivitycross: diff --git a/doc/cpp/interconnectivity.rst b/doc/cpp/interconnectivity.rst index 9bd2bc49a9..bfd9557921 100644 --- a/doc/cpp/interconnectivity.rst +++ b/doc/cpp/interconnectivity.rst @@ -8,11 +8,11 @@ Interconnectivity .. cpp:class:: cell_connection Describes a connection between two cells: a pre-synaptic source and a - post-synaptic destination. The source is typically a threshold detector on - a cell or a spike source. The destination is a synapse on the post-synaptic cell. + post-synaptic target. The source is typically a threshold detector on + a cell or a spike source. The target is a synapse on the post-synaptic cell. - The :cpp:member:`dest` does not include the gid of a cell, this is because a - :cpp:class:`cell_connection` is bound to the destination cell which means that the gid + The :cpp:member:`target` does not include the gid of a cell, this is because a + :cpp:class:`cell_connection` is bound to the target cell which means that the gid is implicitly known. .. cpp:member:: cell_global_label_type source @@ -20,7 +20,7 @@ Interconnectivity Source end point, represented by a :cpp:type:`cell_global_label_type` which packages a cell gid, label of a group of sources on the cell, and source selection policy. - .. cpp:member:: cell_local_label_type dest + .. cpp:member:: cell_local_label_type target Destination end point on the cell, represented by a :cpp:type:`cell_local_label_type` which packages a label of a group of targets on the cell and a selection policy. @@ -41,11 +41,11 @@ Interconnectivity .. cpp:class:: ext_cell_connection Describes a connection between two cells: a pre-synaptic source and a - post-synaptic destination. The source is typically a threshold detector on - a cell or a spike source. The destination is a synapse on the post-synaptic cell. + post-synaptic target. The source is typically a threshold detector on + a cell or a spike source. The target is a synapse on the post-synaptic cell. - The :cpp:member:`dest` does not include the gid of a cell, this is because a - :cpp:class:`ext_cell_connection` is bound to the destination cell which means that the gid + The :cpp:member:`target` does not include the gid of a cell, this is because a + :cpp:class:`ext_cell_connection` is bound to the target cell which means that the gid is implicitly known. .. cpp:member:: cell_remote_label_type source @@ -53,7 +53,7 @@ Interconnectivity Source end point, represented by a :cpp:type:`cell_remote_label_type` which packages a cell gid, integral tag of a group of sources on the cell, and source selection policy. - .. cpp:member:: cell_local_label_type dest + .. cpp:member:: cell_local_label_type target Destination end point on the cell, represented by a :cpp:type:`cell_local_label_type` which packages a label of a group of targets on the cell and a selection policy. @@ -99,3 +99,276 @@ Interconnectivity .. cpp:member:: float weight unit-less gap junction connection weight. + +.. cpp:class:: network_site_info + + A network connection site on a cell. Used for generated connections through the high-level network description. + + .. cpp:member:: cell_gid_type gid + + The cell index. + + .. cpp:member:: cell_kind kind + + The cell kind. + + .. cpp:member:: cell_tag_type label + + The associated label. + + .. cpp:member:: mlocation location + + The local location on the cell. + + .. cpp:member:: mpoint global_location + + The global location in cartesian coordinates. + + +.. cpp:class:: network_connection_info + + A network connection between cells. Used for generated connections through the high-level network description. + + .. cpp:member:: network_site_info source + + The source connection site. + + .. cpp:member:: network_site_info target + + The target connection site. + + +.. cpp:class:: network_value + + A network value, describing the its calculation for each connection. + + .. cpp:function:: network_value scalar(double value) + + A fixed scalar valaue. + + .. cpp:function:: network_value named(std::string name) + + A named network value in the network label dictionary. + + .. cpp:function:: network_value distance() + + The value representing the distance between source and target. + + .. cpp:function:: network_value uniform_distribution(unsigned seed, const std::array& range) + + A uniform random distribution within [range_0, range_1) + + .. cpp:function:: network_value normal_distribution(unsigned seed, double mean, double std_deviation) + + A normal random distribution with given mean and standard deviation. + + .. cpp:function:: network_value truncated_normal_distribution(unsigned seed, double mean, double std_deviation, const std::array& range) + + A truncated normal random distribution with given mean and standard deviation. Sampled through accept-reject method to only returns values in [range_0, range_1) + + .. cpp:function:: network_value custom(custom_func_type func) + + Custom value using the provided function "func". Repeated calls with the same arguments to "func" must yield the same result. + + .. cpp:function:: network_value add(network_value left, network_value right) + + Summation of two values. + + .. cpp:function:: network_value sub(network_value left, network_value right) + + Subtraction of two values. + + .. cpp:function:: network_value mul(network_value left, network_value right) + + Multiplication of two values. + + .. cpp:function:: network_value div(network_value left, network_value right) + + Division of two values. + + .. cpp:function:: network_value min(network_value left, network_value right) + + Minimum of two values. + + .. cpp:function:: network_value max(network_value left, network_value right) + + Maximum of two values. + + .. cpp:function:: network_value exp(network_value v) + + Exponential of given value. + + .. cpp:function:: network_value log(network_value v) + + Logarithm of given value. + + .. cpp:function:: if_else(network_selection cond, network_value true_value, network_value false_value) + + if contained in selection, the true_value is used and the false_value otherwise. + + +.. cpp:class:: network_selection + + A network selection, describing a subset of all possible connections. + + .. cpp:function:: network_selection all() + + Select all + + .. cpp:function:: network_selection none(); + + Select none + + .. cpp:function:: network_selection named(std::string name); + + Named selection in the network label dictionary + + .. cpp:function:: network_selection inter_cell(); + + Only select connections between different cells + + .. cpp:function:: network_selection source_cell_kind(cell_kind kind); + + Select connections with the given source cell kind + + .. cpp:function:: network_selection target_cell_kind(cell_kind kind); + + Select connections with the given target cell kind + + .. cpp:function:: network_selection source_label(std::vector labels); + + Select connections with the given source label + + .. cpp:function:: network_selection target_label(std::vector labels); + + Select connections with the given target label + + .. cpp:function:: network_selection source_cell(std::vector gids); + + Select connections with source cells matching the indices in the list + + .. cpp:function:: network_selection source_cell(gid_range range); + + Select connections with source cells matching the indices in the range + + .. cpp:function:: network_selection target_cell(std::vector gids); + + Select connections with target cells matching the indices in the list + + .. cpp:function:: network_selection target_cell(gid_range range); + + Select connections with target cells matching the indices in the range + + .. cpp:function:: network_selection chain(std::vector gids); + + Select connections that form a chain, such that source cell "i" is connected to the target cell "i+1" + + .. cpp:function:: network_selection chain(gid_range range); + + Select connections that form a chain, such that source cell "i" is connected to the target cell "i+1" + + .. cpp:function:: network_selection chain_reverse(gid_range range); + + Select connections that form a reversed chain, such that source cell "i+1" is connected to the target cell "i" + + .. cpp:function:: network_selection intersect(network_selection left, network_selection right); + + Select connections, that are selected by both "left" and "right" + + .. cpp:function:: network_selection join(network_selection left, network_selection right); + + Select connections, that are selected by either or both "left" and "right" + + .. cpp:function:: network_selection difference(network_selection left, network_selection right); + + Select connections, that are selected by "left", unless selected by "right" + + .. cpp:function:: network_selection symmetric_difference(network_selection left, network_selection right); + + Select connections, that are selected by "left" or "right", but not both + + .. cpp:function:: network_selection complement(network_selection s); + + Invert the selection + + .. cpp:function:: network_selection random(unsigned seed, network_value p); + + Random selection using the bernoulli random distribution with probability "p" between 0.0 and 1.0 + + .. cpp:function:: network_selection custom(custom_func_type func); + + Custom selection using the provided function "func". Repeated calls with the same arguments + to "func" must yield the same result. For gap junction selection, + "func" must be symmetric (func(a,b) = func(b,a)). + + .. cpp:function:: network_selection distance_lt(double d); + + Only select within given distance. This may enable more efficient sampling through an + internal spatial data structure. + + .. cpp:function:: network_selection distance_gt(double d); + + Only select if distance greater then given distance. This may enable more efficient sampling + through an internal spatial data structure. + + +.. cpp:class:: network_label_dict + + Dictionary storing named network values and selections. + + .. cpp:function:: network_label_dict& set(const std::string& name, network_selection s) + + Store a network selection under the given name + + .. cpp:function:: network_label_dict& set(const std::string& name, network_value v) + + Store a network value under the given name + + .. cpp:function:: std::optional selection(const std::string& name) const + + Returns the stored network selection of the given name if it exists. None otherwise. + + .. cpp:function:: std::optional value(const std::string& name) const + + Returns the stored network value of the given name if it exists. None otherwise. + + .. cpp:function:: const ns_map& selections() const + + All stored network selections + + .. cpp:function:: const nv_map& selections() const + + All stored network values + + +.. cpp:class:: network_description + + A complete network description required for processing. + + .. cpp:member:: network_selection selection + + Selection of connections. + + .. cpp:member:: network_value weight + + Weight of generated connections. + + .. cpp:member:: network_value delay + + Delay of generated connections. + + .. cpp:member:: network_label_dict dict + + Label dictionary for named selecations and values. + + +.. function:: generate_network_connections(recipe, context, decomp) + + Generate network connections from the network description in the recipe. Only generates connections + with local gids in the domain composition as target. Does not include connections from + the "connections_on" recipe function. + +.. function:: generate_network_connections(recipe) + + Generate network connections from the network description in the recipe. Returns all generated connections on every process. + Does not include connections from the "connections_on" recipe function. diff --git a/doc/python/interconnectivity.rst b/doc/python/interconnectivity.rst index 2063f412ef..28edc9c114 100644 --- a/doc/python/interconnectivity.rst +++ b/doc/python/interconnectivity.rst @@ -7,15 +7,15 @@ Interconnectivity .. class:: connection - Describes a connection between two cells, defined by source and destination end points (that is pre-synaptic and + Describes a connection between two cells, defined by source and target end points (that is pre-synaptic and post-synaptic respectively), a connection weight and a delay time. - The :attr:`dest` does not include the gid of a cell, this is because a :class:`arbor.connection` is bound to the - destination cell which means that the gid is implicitly known. + The :attr:`target` does not include the gid of a cell, this is because a :class:`arbor.connection` is bound to the + target cell which means that the gid is implicitly known. - .. function:: connection(source, destination, weight, delay) + .. function:: connection(source, target, weight, delay) - Construct a connection between the :attr:`source` and the :attr:`dest` with a :attr:`weight` and :attr:`delay`. + Construct a connection between the :attr:`source` and the :attr:`target` with a :attr:`weight` and :attr:`delay`. .. attribute:: source @@ -23,10 +23,10 @@ Interconnectivity (gid, label) or a (gid, (label, policy)) tuple. If the policy is not indicated, the default :attr:`arbor.selection_policy.univalent` is used). - .. attribute:: dest + .. attribute:: target - The destination end point of the connection (type: :class:`arbor.cell_local_label` representing the label of the - destination on the cell, which can be initialized with just a label, in which case the default + The target end point of the connection (type: :class:`arbor.cell_local_label` representing the label of the + target on the cell, which can be initialized with just a label, in which case the default :attr:`arbor.selection_policy.univalent` is used, or a (label, policy) tuple). The gid of the cell is implicitly known. @@ -63,18 +63,18 @@ Interconnectivity def connections_on(gid): # construct a connection from the "detector" source label on cell 2 # to the "syn" target label on cell gid with weight 0.01 and delay of 10 ms. - src = (2, "detector") # gid and locset label of the source - dest = "syn" # gid of the destination is determined by the argument to `connections_on`. + source = (2, "detector") # gid and locset label of the source + target = "syn" # gid of the target is determined by the argument to `connections_on`. w = 0.01 # weight of the connection. Correspondes to 0.01 μS on expsyn mechanisms d = 10 * arbor.units.ms # delay - return [arbor.connection(src, dest, w, d)] + return [arbor.connection(source, target, w, d)] .. class:: gap_junction_connection Describes a gap junction between two gap junction sites. The :attr:`local` site does not include the gid of a cell, this is because a :class:`arbor.gap_junction_connection` - is bound to the destination cell which means that the gid is implicitly known. + is bound to the target cell which means that the gid is implicitly known. .. note:: @@ -96,7 +96,7 @@ Interconnectivity .. attribute:: local The gap junction site: the local half of the gap junction connection (type: :class:`arbor.cell_local_label` - representing the label of the destination on the cell, which can be initialized with just a label, in which case + representing the label of the target on the cell, which can be initialized with just a label, in which case the default :attr:`arbor.selection_policy.univalent` is used, or a (label, policy) tuple). The gid of the cell is implicitly known. @@ -112,3 +112,70 @@ Interconnectivity .. attribute:: threshold Voltage threshold of threshold detector [mV] + + +.. class:: network_site_info + + A network connection site on a cell. Used for generated connections through the high-level network description. + + .. attribute:: gid + + The cell index. + + .. attribute:: kind + + The cell kind. + + .. attribute:: label + + The associated label. + + .. attribute:: location + + The local location on the cell. + + .. attribute:: global_location + + The global location in cartesian coordinates. + + +.. class:: network_connection_info + + A network connection between cells. Used for generated connections through the high-level network description. + + .. attribute:: source + + The source connection site. + + .. attribute:: target + + The target connection site. + + +.. class:: network_description + + A complete network description required for processing. + + .. attribute:: selection + + Selection of connections. + + .. attribute:: weight + + Weight of generated connections. + + .. attribute:: delay + + Delay of generated connections. + + .. attribute:: dict + + Dictionary for named selecations and values. + + +.. function:: generate_network_connections(recipe, context = None, decomp = None) + + Generate network connections from the network description in the recipe. A distributed context and + domain decomposition can optionally be provided. Only generates connections with local gids in the + domain composition as target. Will return all connections on every process, if no context and domain + decomposition are provided. Does not include connections from the "connections_on" recipe function. diff --git a/doc/python/recipe.rst b/doc/python/recipe.rst index 6acc390867..2aab0b3e38 100644 --- a/doc/python/recipe.rst +++ b/doc/python/recipe.rst @@ -84,7 +84,12 @@ Recipe By default returns an empty list. + .. function:: network_description() + Returns a network description, consisting of a network selection, network value for + weight and delay, and a network dictionary. + + By default returns none. .. function:: gap_junctions_on(gid) @@ -122,6 +127,12 @@ Recipe By default returns an empty object. + .. function:: cell_isometry(gid) + + Returns a isometry consisting of translation and rotation, which is applied to the cell morphology for resolving global locations. + + By default returns a isometry without translation and rotation. + Cells ------ diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index dfbe051e49..3c5c50a108 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -17,6 +17,7 @@ add_subdirectory(lfp) add_subdirectory(diffusion) add_subdirectory(v_clamp) add_subdirectory(ornstein_uhlenbeck) +add_subdirectory(network_description) if(ARB_WITH_MPI) add_subdirectory(remote) diff --git a/example/network_description/CMakeLists.txt b/example/network_description/CMakeLists.txt new file mode 100644 index 0000000000..c4db25b297 --- /dev/null +++ b/example/network_description/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(network_description EXCLUDE_FROM_ALL network_description.cpp) +add_dependencies(examples network_description) + +target_link_libraries(network_description PRIVATE arbor arborio arborenv arbor-sup ${json_library_name}) diff --git a/example/network_description/branch_cell.hpp b/example/network_description/branch_cell.hpp new file mode 100644 index 0000000000..92aa9f6119 --- /dev/null +++ b/example/network_description/branch_cell.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include +#include + +#include + +#include + +#include +#include +#include +#include + +#include +#include + +using namespace arborio::literals; + +namespace U = arb::units; + +// Parameters used to generate the random cell morphologies. +struct cell_parameters { + cell_parameters() = default; + + // Maximum number of levels in the cell (not including the soma) + unsigned max_depth = 5; + + // The following parameters are described as ranges. + // The first value is at the soma, and the last value is used on the last level. + // Values at levels in between are found by linear interpolation. + std::array branch_probs = {1.0, 0.5}; // Probability of a branch occuring. + std::array compartments = {20, 2}; // Compartment count on a branch. + std::array lengths = {200, 20}; // Length of branch in μm. + + // The number of synapses per cell. + unsigned synapses = 1; +}; + +inline cell_parameters parse_cell_parameters(nlohmann::json& json) { + cell_parameters params; + sup::param_from_json(params.max_depth, "depth", json); + sup::param_from_json(params.branch_probs, "branch-probs", json); + sup::param_from_json(params.compartments, "compartments", json); + sup::param_from_json(params.lengths, "lengths", json); + sup::param_from_json(params.synapses, "synapses", json); + + return params; +} + +// Helper used to interpolate in branch_cell. +template +double interp(const std::array& r, unsigned i, unsigned n) { + double p = i * 1./(n-1); + double r0 = r[0]; + double r1 = r[1]; + return r[0] + p*(r1-r0); +} + +inline arb::cable_cell branch_cell(arb::cell_gid_type gid, const cell_parameters& params) { + arb::segment_tree tree; + + // Add soma. + double srad = 12.6157/2.0; // soma radius + int stag = 1; // soma tag + tree.append(arb::mnpos, {0, 0,-srad, srad}, {0, 0, srad, srad}, stag); // For area of 500 μm². + + std::vector> levels; + levels.push_back({0}); + + // Standard mersenne_twister_engine seeded with gid. + std::mt19937 gen(gid); + std::uniform_real_distribution dis(0, 1); + + double drad = 0.5; // Diameter of 1 μm for each dendrite cable. + int dtag = 3; // Dendrite tag. + + double dist_from_soma = srad; // Start dendrite at the edge of the soma. + for (unsigned i=0; i sec_ids; + for (unsigned sec: levels[i]) { + for (unsigned j=0; j<2; ++j) { + if (dis(gen) 1) { + decor.place(arb::ls::uniform("dend"_lab, 0, params.synapses - 2, gid), arb::synapse("expsyn"), "extra_syns"); + } + + // Make a CV between every sample in the sample tree. + decor.set_default(arb::cv_policy_every_segment()); + + arb::cable_cell cell(arb::morphology(tree), decor, labels); + + return cell; +} diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp new file mode 100644 index 0000000000..8f80814be7 --- /dev/null +++ b/example/network_description/network_description.cpp @@ -0,0 +1,343 @@ +/* + * A miniapp that demonstrates how to use network expressions + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "branch_cell.hpp" + +#ifdef ARB_MPI_ENABLED +#include +#include +#endif + +struct ring_params { + ring_params() = default; + + std::string name = "default"; + unsigned num_cells = 20; + double min_delay = 10; + double duration = 100; + cell_parameters cell; +}; + +ring_params read_options(int argc, char** argv); +using arb::cell_gid_type; +using arb::cell_kind; +using arb::cell_lid_type; +using arb::cell_member_type; +using arb::cell_size_type; +using arb::time_type; + +// Writes voltage trace as a json file. +void write_trace_json(const arb::trace_data& trace); + +// Generate a cell. +arb::cable_cell branch_cell(arb::cell_gid_type gid, const cell_parameters& params); + +class ring_recipe: public arb::recipe { +public: + ring_recipe(unsigned num_cells, cell_parameters params, unsigned min_delay): + num_cells_(num_cells), + cell_params_(params), + min_delay_(min_delay) { + gprop_.default_parameters = arb::neuron_parameter_defaults; + } + + cell_size_type num_cells() const override { return num_cells_; } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + return branch_cell(gid, cell_params_); + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::cable; } + + arb::isometry get_cell_isometry(cell_gid_type gid) const override { + // place cells with equal distance on a circle + const double angle = 2 * 3.1415926535897932 * gid / num_cells_; + const double radius = 500.0; + return arb::isometry::translate(radius * std::cos(angle), radius * std::sin(angle), 0.0); + }; + + std::optional network_description() const override { + // create a chain + auto chain = arb::network_selection::chain(arb::gid_range(0, num_cells_)); + // connect front and back of chain to form ring + auto ring = arb::join(chain, + arb::intersect(arb::network_selection::source_cell({num_cells_ - 1}), + arb::network_selection::target_cell({0}))); + + // Create random connections with probability inversely proportional to the distance within + // a radius + const double max_dist = 400.0; + auto probability = (max_dist - arb::network_value::distance()) / max_dist; + + // restrict to inter-cell connections and to distance within radius + auto seed = 42; + auto rand = intersect(arb::network_selection::random(seed, probability), + arb::network_selection::distance_lt(max_dist), + arb::network_selection::inter_cell()); + + // combine ring with random selection + auto s = join(ring, rand); + + // restrict to certain source and target labels + s = arb::intersect(s, + arb::network_selection::source_label({"detector"}), + arb::network_selection::target_label({"primary_syn"})); + + // random normal distributed weight with mean 0.05 μS, standard deviation 0.02 μS + // and truncated to [0.025, 0.075] + auto w_rand = "(truncated-normal-distribution 42 0.05 0.02 0.025 0.075)"_nv; + // note: We are using s-expressions here as an alternative for creating a network_value. + // This alternative way is also available for network selections. + + // fixed weight for connections in ring + auto w_ring = "(scalar 0.01)"_nv; + + // combine into single weight by using the "ring" selection as condition + auto w = arb::network_value::if_else(ring, w_ring, w_rand); + + return arb::network_description{s, w, min_delay_, {}}; + }; + + // Return one event generator on gid 0. This generates a single event that will + // kick start the spiking. + std::vector event_generators(cell_gid_type gid) const override { + std::vector gens; + if (!gid) { + gens.push_back(arb::explicit_generator_from_milliseconds( + {"primary_syn"}, event_weight_, std::vector{1.0})); + } + return gens; + } + + std::vector get_probes(cell_gid_type gid) const override { + // Measure membrane voltage at end of soma. + arb::mlocation loc{0, 0.0}; + return {{arb::cable_probe_membrane_voltage{loc}, "Um"}}; + } + + std::any get_global_properties(arb::cell_kind) const override { return gprop_; } + +private: + cell_size_type num_cells_; + cell_parameters cell_params_; + double min_delay_; + float event_weight_ = 0.05; + arb::cable_cell_global_properties gprop_; +}; + +int main(int argc, char** argv) { + try { + bool root = true; + + arb::proc_allocation resources; + resources.num_threads = arbenv::default_concurrency(); + +#ifdef ARB_MPI_ENABLED + arbenv::with_mpi guard(argc, argv, false); + resources.gpu_id = arbenv::find_private_gpu(MPI_COMM_WORLD); + auto context = arb::make_context(resources, MPI_COMM_WORLD); + root = arb::rank(context) == 0; +#else + resources.gpu_id = arbenv::default_gpu(); + auto context = arb::make_context(resources); +#endif + +#ifdef ARB_PROFILE_ENABLED + arb::profile::profiler_initialize(context); +#endif + + std::cout << sup::mask_stream(root); + + // Print a banner with information about hardware configuration + std::cout << "gpu: " << (has_gpu(context) ? "yes" : "no") << "\n"; + std::cout << "threads: " << num_threads(context) << "\n"; + std::cout << "mpi: " << (has_mpi(context) ? "yes" : "no") << "\n"; + std::cout << "ranks: " << num_ranks(context) << "\n" << std::endl; + + auto params = read_options(argc, argv); + + arb::profile::meter_manager meters; + meters.start(context); + + // Create an instance of our recipe. + ring_recipe recipe(params.num_cells, params.cell, params.min_delay); + + // Construct the model. + auto decomposition = arb::partition_load_balance(recipe, context); + arb::simulation sim(recipe, context, decomposition); + + // Set up the probe that will measure voltage in the cell. + + // The id of the only probe on the cell: the cell_member type points to (cell 0, probe 0) + auto probeset_id = arb::cell_address_type{0, "Um"}; + // The schedule for sampling is 10 samples every 1 ms. + auto sched = arb::regular_schedule(1*arb::units::ms); + // This is where the voltage samples will be stored as (time, value) pairs + arb::trace_vector voltage; + // Now attach the sampler at probeset_id, with sampling schedule sched, writing to voltage + sim.add_sampler(arb::one_probe(probeset_id), sched, arb::make_simple_sampler(voltage)); + + // Set up recording of spikes to a vector on the root process. + std::vector recorded_spikes; + if (root) { + sim.set_global_spike_callback( + [&recorded_spikes](const std::vector& spikes) { + recorded_spikes.insert(recorded_spikes.end(), spikes.begin(), spikes.end()); + }); + } + + meters.checkpoint("model-init", context); + + if (root) { sim.set_epoch_callback(arb::epoch_progress_bar()); } + std::cout << "running simulation\n" << std::endl; + // Run the simulation for 100 ms, with time steps of 0.025 ms. + sim.run(params.duration*arb::units::ms, 0.025*arb::units::ms); + + meters.checkpoint("model-run", context); + + // Print generated connections + if (root) { + const auto connections = arb::generate_network_connections(recipe); + std::cout << "Connections:" << std::endl; + for (const auto& c: connections) { + std::cout << "(" << c.source.gid << ", \"" << c.source.label << "\") ->"; + std::cout << "(" << c.target.gid << ", \"" << c.target.label << "\")" << std::endl; + } + } + + auto ns = sim.num_spikes(); + + // Write spikes to file + if (root) { + std::cout << "\n" + << ns << " spikes generated at rate of " << params.duration / ns + << " ms between spikes\n"; + std::ofstream fid("spikes.gdf"); + if (!fid.good()) { + std::cerr << "Warning: unable to open file spikes.gdf for spike output\n"; + } + else { + char linebuf[45]; + for (auto spike: recorded_spikes) { + auto n = std::snprintf(linebuf, + sizeof(linebuf), + "%u %.4f\n", + unsigned{spike.source.gid}, + float(spike.time)); + fid.write(linebuf, n); + } + } + } + + // Write the samples to a json file. + if (root) { write_trace_json(voltage.at(0)); } + + auto profile = arb::profile::profiler_summary(); + std::cout << profile << "\n"; + + auto report = arb::profile::make_meter_report(meters, context); + std::cout << report; + } + catch (std::exception& e) { + std::cerr << "exception caught in ring miniapp: " << e.what() << "\n"; + return 1; + } + + return 0; +} + +void write_trace_json(const arb::trace_data& trace) { + std::string path = "./voltages.json"; + + nlohmann::json json; + json["name"] = "ring demo"; + json["units"] = "mV"; + json["cell"] = "0.0"; + json["probe"] = "0"; + + auto& jt = json["data"]["time"]; + auto& jy = json["data"]["voltage"]; + + for (const auto& sample: trace) { + jt.push_back(sample.t); + jy.push_back(sample.v); + } + + std::ofstream file(path); + file << std::setw(1) << json << "\n"; +} + +ring_params read_options(int argc, char** argv) { + using sup::param_from_json; + + ring_params params; + if (argc < 2) { + std::cout << "Using default parameters.\n"; + return params; + } + if (argc > 2) { + throw std::runtime_error("More than one command line option is not permitted."); + } + + std::string fname = argv[1]; + std::cout << "Loading parameters from file: " << fname << "\n"; + std::ifstream f(fname); + + if (!f.good()) { throw std::runtime_error("Unable to open input parameter file: " + fname); } + + nlohmann::json json; + f >> json; + + param_from_json(params.name, "name", json); + param_from_json(params.num_cells, "num-cells", json); + param_from_json(params.duration, "duration", json); + param_from_json(params.min_delay, "min-delay", json); + params.cell = parse_cell_parameters(json); + + if (!json.empty()) { + for (auto it = json.begin(); it != json.end(); ++it) { + std::cout << " Warning: unused input parameter: \"" << it.key() << "\"\n"; + } + std::cout << "\n"; + } + + return params; +} diff --git a/example/network_description/readme.md b/example/network_description/readme.md new file mode 100644 index 0000000000..97131ae1a6 --- /dev/null +++ b/example/network_description/readme.md @@ -0,0 +1,3 @@ +# Ring Example + +A miniapp that demonstrates how to describe how to build a simple ring network with random interconnection using the network description language. diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 21b0bdc16f..c590da1f12 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -34,6 +34,7 @@ set(pyarb_source mechanism.cpp morphology.cpp mpi.cpp + network.cpp profiler.cpp pyarb.cpp label_dict.cpp diff --git a/python/example/network_description.py b/python/example/network_description.py new file mode 100755 index 0000000000..a0cf01a863 --- /dev/null +++ b/python/example/network_description.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# This script is included in documentation. Adapt line numbers if touched. + +import arbor as A +from arbor import units as U +import pandas # You may have to pip install these +import seaborn # You may have to pip install these +from math import sqrt + + +def make_cable_cell(gid): + # (1) Build a segment tree + # The dendrite (dend) attaches to the soma and has two simple segments + # attached. + # + # left + # / + # soma - dend + # \ + # right + tree = A.segment_tree() + root = A.mnpos + # Soma (tag=1) with radius 6 μm, modelled as cylinder of length 2*radius + soma = tree.append(root, (-12, 0, 0, 6), (0, 0, 0, 6), tag=1) + # Single dendrite (tag=3) of length 50 μm and radius 2 μm attached to soma. + dend = tree.append(soma, (0, 0, 0, 2), (50, 0, 0, 2), tag=3) + # Attach two dendrites (tag=3) of length 50 μm to the end of the first dendrite. + # Radius tapers from 2 to 0.5 μm over the length of the dendrite. + l = 50 / sqrt(2) + _ = tree.append( + dend, + (50, 0, 0, 2), + (50 + l, l, 0, 0.5), + tag=3, + ) + # Constant radius of 1 μm over the length of the dendrite. + _ = tree.append( + dend, + (50, 0, 0, 1), + (50 + l, -l, 0, 1), + tag=3, + ) + + # Associate labels to tags + labels = A.label_dict( + { + "soma": "(tag 1)", + "dend": "(tag 3)", + # (2) Mark location for synapse at the midpoint of branch 1 (the first dendrite). + "synapse_site": "(location 1 0.5)", + # Mark the root of the tree. + "root": "(root)", + } + ) + + # (3) Create a decor and a cable_cell + decor = ( + A.decor() + # Put hh dynamics on soma, and passive properties on the dendrites. + .paint('"soma"', A.density("hh")).paint('"dend"', A.density("pas")) + # (4) Attach a single synapse. + .place('"synapse_site"', A.synapse("expsyn"), "syn") + # Attach a detector with threshold of -10 mV. + .place('"root"', A.threshold_detector(-10 * U.mV), "detector") + ) + + return A.cable_cell(tree, decor, labels) + + +# (5) Create a recipe that generates a network of connected cells. +class random_ring_recipe(A.recipe): + def __init__(self, ncells): + # Base class constructor must be called first for proper initialization. + A.recipe.__init__(self) + self.ncells = ncells + self.props = A.neuron_cable_properties() + + # (6) Returns the total number of cells in the model; must be implemented. + def num_cells(self): + return self.ncells + + # (7) The cell_description method returns a cell + def cell_description(self, gid): + return make_cable_cell(gid) + + # Return the type of cell; must be implemented and match cell_description. + def cell_kind(self, _): + return A.cell_kind.cable + + # (8) Descripe network + def network_description(self): + seed = 42 + + # create a chain + chain = f"(chain (gid-range 0 {self.ncells}))" + # connect front and back of chain to form ring + ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (target-cell 0)))" + + # Create random connections with probability inversely proportional to the distance within a + # radius + max_dist = 400.0 # μm + probability = f"(div (sub {max_dist} (distance)) {max_dist})" + rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" + + # combine ring with random selection + s = f"(join {ring} {rand})" + # restrict to inter-cell connections and certain source / target labels + s = f'(intersect {s} (inter-cell) (source-label "detector") (target-label "syn"))' + + # fixed weight for connections in ring + w_ring = "(scalar 0.01)" + # random normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS + # and truncated to [0.005, 0.035] + w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" + + # combine into single weight expression + w = f"(if-else {ring} {w_ring} {w_rand})" + + # fixed delay + d = "(scalar 5.0)" # ms delay + + return A.network_description(s, w, d, {}) + + # (9) Attach a generator to the first cell in the ring. + def event_generators(self, gid): + if gid == 0: + sched = A.explicit_schedule([1 * U.ms]) # one event at 1 ms + weight = 0.1 # 0.1 μS on expsyn + return [A.event_generator("syn", weight, sched)] + return [] + + # (10) Place a probe at the root of each cell. + def probes(self, gid): + return [A.cable_probe_membrane_voltage('"root"', "Um")] + + def global_properties(self, _): + return self.props + + +# (11) Instantiate recipe +ncells = 4 +recipe = random_ring_recipe(ncells) + +# (12) Create a simulation using the default settings: +# - Use all threads available +# - Use round-robin distribution of cells across groups with one cell per group +# - Use GPU if present +# - No MPI +# Other constructors of simulation can be used to change all of these. +sim = A.simulation(recipe) + +# (13) Set spike generators to record +sim.record(A.spike_recording.all) + +# (14) Attach a sampler to the voltage probe on cell 0. Sample rate of 10 sample every ms. +handles = [ + sim.sample((gid, "Um"), A.regular_schedule(0.1 * U.ms)) for gid in range(ncells) +] + +# (15) Run simulation for 100 ms +sim.run(100 * U.ms) +print("Simulation finished") + +# (16) Print spike times +print("spikes:") +for sp in sim.spikes(): + print(" ", sp) + +# (17) Plot the recorded voltages over time. +print("Plotting results ...") +dfs = [] +for gid in range(ncells): + samples, meta = sim.samples(handles[gid])[0] + dfs.append( + pandas.DataFrame( + {"t/ms": samples[:, 0], "U/mV": samples[:, 1], "Cell": f"cell {gid}"} + ) + ) +df = pandas.concat(dfs, ignore_index=True) +seaborn.relplot( + data=df, kind="line", x="t/ms", y="U/mV", hue="Cell", errorbar=None +).savefig("network_description_result.svg") diff --git a/python/network.cpp b/python/network.cpp new file mode 100644 index 0000000000..d1bf3858b5 --- /dev/null +++ b/python/network.cpp @@ -0,0 +1,163 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "context.hpp" +#include "error.hpp" +#include "recipe.hpp" +#include "strprintf.hpp" +#include "util.hpp" + +namespace py = pybind11; + +namespace pyarb { + +void register_network(py::module& m) { + using namespace py::literals; + + py::class_ network_site_info( + m, "network_site_info", "Identifies a network site to connect to / from"); + network_site_info.def_readwrite("gid", &arb::network_site_info::gid) + .def_readwrite("kind", &arb::network_site_info::kind) + .def_readwrite("label", &arb::network_site_info::label) + .def_readwrite("location", &arb::network_site_info::location) + .def_readwrite("global_location", &arb::network_site_info::global_location) + .def("__repr__", [](const arb::network_site_info& s) { return util::pprintf("{}", s); }) + .def("__str__", [](const arb::network_site_info& s) { return util::pprintf("{}", s); }); + + py::class_ network_connection_info( + m, "network_connection_info", "Identifies a network connection"); + network_connection_info.def_readwrite("source", &arb::network_connection_info::source) + .def_readwrite("target", &arb::network_connection_info::target) + .def_readwrite("weight", &arb::network_connection_info::weight) + .def_readwrite("delay", &arb::network_connection_info::delay) + .def("__repr__", + [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }) + .def("__str__", + [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }); + + py::class_ network_selection( + m, "network_selection", "Network selection."); + + network_selection + .def_static("custom", + [](arb::network_selection::custom_func_type func) { + return arb::network_selection::custom([=](const arb::network_site_info& source, const arb::network_site_info& target) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(source, target); + }, + "Python error already thrown"); + }); + }) + .def("__str__", + [](const arb::network_selection& s) { + return util::pprintf("", s); + }) + .def("__repr__", [](const arb::network_selection& s) { return util::pprintf("{}", s); }); + + py::class_ network_value(m, "network_value", "Network value."); + network_value + .def_static("custom", + [](arb::network_value::custom_func_type func) { + return arb::network_value::custom([=](const arb::network_site_info& source, const arb::network_site_info& target) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(source, target); + }, + "Python error already thrown"); + }); + }) + .def("__str__", + [](const arb::network_value& v) { + return util::pprintf("", v); + }) + .def("__repr__", [](const arb::network_value& v) { return util::pprintf("{}", v); }); + + py::class_ network_description( + m, "network_description", "Network description."); + network_description.def( + py::init( + [](std::string selection, + std::string weight, + std::string delay, + std::unordered_map> map) { + arb::network_label_dict dict; + for (const auto& [label, v]: map) { + const auto& dict_label = label; + std::visit( + arb::util::overload( + [&](const std::string& s) { + auto sel = arborio::parse_network_selection_expression(s); + if (sel) { + dict.set(dict_label, *sel); + return; + } + + auto val = arborio::parse_network_value_expression(s); + if (val) { + dict.set(dict_label, *val); + return; + } + + throw pyarb_error( + std::string("Failed to parse \"") + dict_label + + "\" label in dict of network description. \nSelection " + "label parse error:\n" + + sel.error().what() + "\nValue label parse error:\n" + + val.error().what()); + }, + [&](const arb::network_selection& sel) { dict.set(dict_label, sel); }, + [&](const arb::network_value& val) { dict.set(dict_label, val); }), + v); + } + auto desc = arb::network_description{ + arborio::parse_network_selection_expression(selection).unwrap(), + arborio::parse_network_value_expression(weight).unwrap(), + arborio::parse_network_value_expression(delay).unwrap(), + dict}; + return desc; + }), + "selection"_a, + "weight"_a, + "delay"_a, + "dict"_a, + "Construct network description."); + + m.def( + "generate_network_connections", + [](const std::shared_ptr& rec, + std::shared_ptr ctx, + std::optional decomp) { + recipe_shim rec_shim(rec); + + if (!ctx) ctx = std::make_shared(arb::make_context()); + if (!decomp) decomp = arb::partition_load_balance(rec_shim, ctx->context); + + return generate_network_connections(rec_shim, ctx->context, decomp.value()); + }, + "recipe"_a, + "context"_a = pybind11::none(), + "decomp"_a = pybind11::none(), + "Generate network connections from the network description in the recipe. Will only " + "generate connections with local gids in the domain composition as target."); +} + +} // namespace pyarb diff --git a/python/pyarb.cpp b/python/pyarb.cpp index 7f77d673a2..d4c14f0766 100644 --- a/python/pyarb.cpp +++ b/python/pyarb.cpp @@ -30,6 +30,7 @@ void register_recipe(pybind11::module& m); void register_schedules(pybind11::module& m); void register_simulation(pybind11::module& m, pyarb_global_ptr); void register_arborenv(pybind11::module& m); +void register_network(pybind11::module& m); void register_single_cell(pybind11::module& m); void register_units(pybind11::module& m); void register_label_dict(pybind11::module& m); @@ -70,6 +71,7 @@ PYBIND11_MODULE(_arbor, m) { pyarb::register_simulation(m, global_ptr); pyarb::register_arborenv(m); pyarb::register_single_cell(m); + pyarb::register_network(m); // This is the fallback. All specific translators take precedence by being // registered *later*. diff --git a/python/recipe.cpp b/python/recipe.cpp index 667be8dec3..a49209220a 100644 --- a/python/recipe.cpp +++ b/python/recipe.cpp @@ -205,6 +205,10 @@ void register_recipe(pybind11::module& m) { .def("gap_junctions_on", &recipe::gap_junctions_on, "gid"_a, "A list of the gap junctions connected to gid, [] by default.") + .def("network_description", &recipe::network_description, + "Network description of cell connections.") + .def("cell_isometry", &recipe::cell_isometry, + "Isometry describing translation and rotation of cell.") .def("probes", &recipe::probes, "gid"_a, "The probes to allow monitoring.") diff --git a/python/recipe.hpp b/python/recipe.hpp index 9e0adc2a84..3ecd1aa312 100644 --- a/python/recipe.hpp +++ b/python/recipe.hpp @@ -1,13 +1,16 @@ #pragma once #include +#include #include #include #include -#include #include +#include +#include +#include #include #include "error.hpp" @@ -50,6 +53,12 @@ class recipe { virtual pybind11::object global_properties(arb::cell_kind kind) const { return pybind11::none(); }; + virtual std::optional network_description() const { + return std::nullopt; + }; + virtual arb::isometry cell_isometry(arb::cell_gid_type gid) const { + return arb::isometry(); + }; }; class recipe_trampoline: public recipe { @@ -87,6 +96,14 @@ class recipe_trampoline: public recipe { PYBIND11_OVERRIDE(std::vector, recipe, gap_junctions_on, gid); } + std::optional network_description() const override { + PYBIND11_OVERRIDE(std::optional, recipe, network_description); + } + + arb::isometry cell_isometry(arb::cell_gid_type gid) const override { + PYBIND11_OVERRIDE(arb::isometry, recipe, cell_isometry, gid); + }; + std::vector probes(arb::cell_gid_type gid) const override { PYBIND11_OVERRIDE(std::vector, recipe, probes, gid); } @@ -140,6 +157,14 @@ class recipe_shim: public ::arb::recipe { } std::any get_global_properties(arb::cell_kind kind) const override; + + std::optional network_description() const override { + return try_catch_pyexception([&]() { return impl_->network_description(); }, msg); + }; + + arb::isometry get_cell_isometry(arb::cell_gid_type gid) const override { + return try_catch_pyexception([&]() { return impl_->cell_isometry(gid); }, msg); + }; }; } // namespace pyarb diff --git a/scripts/run_cpp_examples.sh b/scripts/run_cpp_examples.sh index fd07affd11..b916c23594 100755 --- a/scripts/run_cpp_examples.sh +++ b/scripts/run_cpp_examples.sh @@ -36,6 +36,7 @@ all_examples=( "plasticity" "ou" "voltage-clamp" + "network_description" "remote" ) @@ -58,6 +59,7 @@ expected_outputs=( "" "" "" + 46 "" ) diff --git a/scripts/run_python_examples.sh b/scripts/run_python_examples.sh index 6fb9e0e5cf..f92cbe16cf 100755 --- a/scripts/run_python_examples.sh +++ b/scripts/run_python_examples.sh @@ -35,6 +35,7 @@ runpyex network_ring.py # runpyex network_ring_mpi_plot.py # no need to test runpyex network_ring_gpu.py # by default, gpu_id=None runpyex network_two_cells_gap_junctions.py +runpyex network_ring.py runpyex diffusion.py runpyex plasticity.py runpyex v-clamp.py diff --git a/test/unit-distributed/CMakeLists.txt b/test/unit-distributed/CMakeLists.txt index 545d18e021..96ea5abf4b 100644 --- a/test/unit-distributed/CMakeLists.txt +++ b/test/unit-distributed/CMakeLists.txt @@ -3,6 +3,8 @@ set(unit-distributed_sources test_domain_decomposition.cpp test_communicator.cpp test_mpi.cpp + test_distributed_for_each.cpp + test_network_generation.cpp # unit test driver test.cpp @@ -13,7 +15,7 @@ add_dependencies(tests unit-local) target_compile_options(unit-local PRIVATE ${ARB_CXX_FLAGS_TARGET_FULL}) target_compile_definitions(unit-local PRIVATE TEST_LOCAL) -target_link_libraries(unit-local PRIVATE ext-gtest arbor arborenv arbor-sup arbor-private-headers ext-tinyopt) +target_link_libraries(unit-local PRIVATE ext-gtest arbor arborenv arborio arbor-sup arbor-private-headers ext-tinyopt) if(ARB_WITH_MPI) add_executable(unit-mpi EXCLUDE_FROM_ALL ${unit-distributed_sources}) @@ -21,6 +23,6 @@ if(ARB_WITH_MPI) target_compile_options(unit-mpi PRIVATE ${ARB_CXX_FLAGS_TARGET_FULL}) target_compile_definitions(unit-mpi PRIVATE TEST_MPI) - target_link_libraries(unit-mpi PRIVATE ext-gtest arbor arborenv arbor-sup arbor-private-headers ext-tinyopt) + target_link_libraries(unit-mpi PRIVATE ext-gtest arbor arborenv arborio arbor-sup arbor-private-headers ext-tinyopt) endif() diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index 0100a30e0b..0e46459e04 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -216,7 +216,7 @@ namespace { std::vector connections_on(cell_gid_type gid) const override { // a single connection from the preceding cell, i.e. a ring - // weight is the destination gid + // weight is the target gid // delay is 1 cell_global_label_type src = {gid==0? size_-1: gid-1, "src"}; cell_local_label_type dst = {"tgt"}; @@ -529,7 +529,7 @@ TEST(communicator, ring) auto global_sources = g_context->distributed->gather_cell_labels_and_gids(local_sources); // construct the communicator - auto C = communicator(R, D, *g_context); + auto C = communicator(R, D, g_context); C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets)); // every cell fires EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;})); @@ -637,7 +637,7 @@ TEST(communicator, all2all) auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, mc_gids}); // construct the communicator - auto C = communicator(R, D, *g_context); + auto C = communicator(R, D, g_context); C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids})); auto connections = C.connections(); @@ -684,7 +684,7 @@ TEST(communicator, mini_network) auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, gids}); // construct the communicator - auto C = communicator(R, D, *g_context); + auto C = communicator(R, D, g_context); C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids})); // sort connections by source then target diff --git a/test/unit-distributed/test_distributed_for_each.cpp b/test/unit-distributed/test_distributed_for_each.cpp new file mode 100644 index 0000000000..d3965892db --- /dev/null +++ b/test/unit-distributed/test_distributed_for_each.cpp @@ -0,0 +1,92 @@ +#include +#include "test.hpp" + +#include +#include +#include + +#include "communication/distributed_for_each.hpp" +#include "execution_context.hpp" +#include "util/rangeutil.hpp" + +using namespace arb; + +// check when all input is size 0 +TEST(distributed_for_each, all_zero) { + std::vector data; + + const int num_ranks = g_context->distributed->size(); + int call_count = 0; + + auto sample = [&](const util::range& range) { + EXPECT_EQ(0, range.size()); + ++call_count; + }; + + distributed_for_each( + sample, *g_context->distributed, util::make_range(data.begin(), data.end())); + + EXPECT_EQ(num_ranks, call_count); +} + +// check when input on one rank is size 0 +TEST(distributed_for_each, one_zero) { + const auto rank = g_context->distributed->id(); + const int num_ranks = g_context->distributed->size(); + int call_count = 0; + + // test data size is equal to rank id and vector is filled with rank id + std::vector data; + for (int i = 0; i < rank; ++i) { data.push_back(rank); } + + auto sample = [&](const util::range& range) { + const auto origin_rank = range.empty() ? 0 : range.front(); + + EXPECT_EQ(origin_rank, range.size()); + for (const auto& value: range) { EXPECT_EQ(value, origin_rank); } + ++call_count; + }; + + distributed_for_each( + sample, *g_context->distributed, util::make_range(data.begin(), data.end())); + + EXPECT_EQ(num_ranks, call_count); +} + +// check multiple types +TEST(distributed_for_each, multiple) { + const auto rank = g_context->distributed->id(); + const int num_ranks = g_context->distributed->size(); + int call_count = 0; + + std::vector data_1; + std::vector data_2; + std::vector> data_3; + // test data size is equal to rank id + 1and vector is filled with rank id + for (int i = 0; i < rank + 1; ++i) { data_1.push_back(rank); } + // test different data sizes for each type + for (std::size_t i = 0; i < 2 * data_1.size(); ++i) { data_2.push_back(rank); } + for (std::size_t i = 0; i < 3 * data_1.size(); ++i) { data_3.push_back(rank); } + + auto sample = [&](const util::range& range_1, + const util::range& range_2, + const util::range*>& range_3) { + const auto origin_rank = range_1.empty() ? 0 : range_1.front(); + + EXPECT_EQ(origin_rank + 1, range_1.size()); + EXPECT_EQ(range_2.size(), 2 * range_1.size()); + EXPECT_EQ(range_3.size(), 3 * range_1.size()); + for (const auto& value: range_1) { EXPECT_EQ(value, origin_rank); } + for (const auto& value: range_2) { EXPECT_EQ(value, double(origin_rank)); } + for (const auto& value: range_3) { EXPECT_EQ(value, std::complex(origin_rank)); } + ++call_count; + }; + + distributed_for_each(sample, + *g_context->distributed, + util::range_view(data_1), + util::range_view(data_2), + util::range_view(data_3)); + + EXPECT_EQ(num_ranks, call_count); +} diff --git a/test/unit-distributed/test_network_generation.cpp b/test/unit-distributed/test_network_generation.cpp new file mode 100644 index 0000000000..f7e3b55fe3 --- /dev/null +++ b/test/unit-distributed/test_network_generation.cpp @@ -0,0 +1,167 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "execution_context.hpp" +#include "test.hpp" + +using namespace arb; +using namespace arborio::literals; +namespace U = arb::units; + +namespace { +// Create alternatingly a cable, lif and spike source cell with at most one source or target +class network_test_recipe: public arb::recipe { +public: + network_test_recipe(unsigned num_cells, + network_selection selection, + network_value weight, + network_value delay): + num_cells_(num_cells), + selection_(selection), + weight_(weight), + delay_(delay) { + gprop_.default_parameters = arb::neuron_parameter_defaults; + } + + cell_size_type num_cells() const override { return num_cells_; } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + if (gid % 3 == 1) { return lif_cell("source", "target"); } + if (gid % 3 == 2) { return spike_source_cell("spike_source"); } + + // cable cell + int stag = 1; // soma tag + int dtag = 3; // Dendrite tag. + double srad = 12.6157 / 2.0; // soma radius + double drad = 0.5; // Diameter of 1 μm for each dendrite cable. + arb::segment_tree tree; + tree.append( + arb::mnpos, {0, 0, -srad, srad}, {0, 0, srad, srad}, stag); // For area of 500 μm². + tree.append(0, {0, 0, 2 * srad, drad}, dtag); + + arb::label_dict labels; + labels.set("soma", reg::tagged(stag)); + labels.set("dend", reg::tagged(dtag)); + + auto decor = arb::decor{} + .paint("soma"_lab, arb::density("hh")) + .paint("dend"_lab, arb::density("pas")) + .set_default(arb::axial_resistivity{100*U::Ohm*U::cm}) // [Ω·cm] + .place(arb::mlocation{0, 0}, arb::threshold_detector{10*U::mV}, "detector") + .place(arb::mlocation{0, 0.5}, arb::synapse("expsyn"), "primary_syn"); + + return arb::cable_cell(arb::morphology(tree), decor, labels); + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + if (gid % 3 == 1) { return cell_kind::lif; } + if (gid % 3 == 2) { return cell_kind::spike_source; } + + return cell_kind::cable; + } + + arb::isometry get_cell_isometry(cell_gid_type gid) const override { + // place cells with equal distance on a circle + const double angle = 2 * 3.1415926535897932 * gid / num_cells_; + const double radius = 500.0; + return arb::isometry::translate(radius * std::cos(angle), radius * std::sin(angle), 0.0); + }; + + std::optional network_description() const override { + return arb::network_description{selection_, weight_, delay_, {}}; + }; + + std::vector event_generators(cell_gid_type gid) const override { + return {}; + } + + std::vector get_probes(cell_gid_type gid) const override { return {}; } + + std::any get_global_properties(arb::cell_kind) const override { return gprop_; } + +private: + cell_size_type num_cells_; + arb::cable_cell_global_properties gprop_; + network_selection selection_; + network_value weight_, delay_; +}; + +} // namespace + +TEST(network_generation, all) { + const auto& ctx = g_context; + const int num_ranks = ctx->distributed->size(); + + const auto selection = network_selection::all(); + const auto weight = 2.0; + const auto delay = 3.0; + + const auto num_cells = 3 * num_ranks; + + auto rec = network_test_recipe(num_cells, selection, weight, delay); + + const auto decomp = partition_load_balance(rec, ctx); + + const auto connections = generate_network_connections(rec, ctx, decomp); + + std::unordered_map> connections_by_dest; + + for (const auto& c: connections) { + EXPECT_EQ(c.weight, weight); + EXPECT_EQ(c.delay, delay); + connections_by_dest[c.target.gid].emplace_back(c); + } + + for (const auto& group: decomp.groups()) { + const auto num_dest = group.kind == cell_kind::spike_source ? 0 : 1; + for (const auto gid: group.gids) { + EXPECT_EQ(connections_by_dest[gid].size(), num_cells * num_dest); + } + } +} + +TEST(network_generation, cable_only) { + const auto& ctx = g_context; + const int num_ranks = ctx->distributed->size(); + + const auto selection = intersect(network_selection::source_cell_kind(cell_kind::cable), + network_selection::target_cell_kind(cell_kind::cable)); + const auto weight = 2.0; + const auto delay = 3.0; + + const auto num_cells = 3 * num_ranks; + + auto rec = network_test_recipe(num_cells, selection, weight, delay); + + const auto decomp = partition_load_balance(rec, ctx); + + const auto connections = generate_network_connections(rec, ctx, decomp); + + std::unordered_map> connections_by_dest; + + for (const auto& c: connections) { + EXPECT_EQ(c.weight, weight); + EXPECT_EQ(c.delay, delay); + connections_by_dest[c.target.gid].emplace_back(c); + } + + for (const auto& group: decomp.groups()) { + for (const auto gid: group.gids) { + // Only one third is a cable cell + EXPECT_EQ(connections_by_dest[gid].size(), + group.kind == cell_kind::cable ? num_cells / 3 : 0); + } + } +} diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 6da2140a82..b75a4555bc 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -109,6 +109,7 @@ set(unit_sources test_morph_place.cpp test_morph_primitives.cpp test_morph_stitch.cpp + test_network.cpp test_ordered_forest.cpp test_padded.cpp test_partition.cpp @@ -128,6 +129,7 @@ set(unit_sources test_simd.cpp test_simulation.cpp test_span.cpp + test_spatial_tree.cpp test_spike_source.cpp test_spikes.cpp test_spike_store.cpp diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp index 915a346ae9..ac336616e6 100644 --- a/test/unit/test_domain_decomposition.cpp +++ b/test/unit/test_domain_decomposition.cpp @@ -163,6 +163,15 @@ struct dummy_context { cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { throw unimplemented{__FUNCTION__}; } cell_labels_and_gids gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const { throw unimplemented{__FUNCTION__}; } template std::vector gather(T value, int) const { throw unimplemented{__FUNCTION__}; } + distributed_request send_recv_nonblocking(std::size_t dest_count, + void* dest_data, + int dest, + std::size_t source_count, + const void* source_data, + int source, + int tag) const { + throw unimplemented{__FUNCTION__}; + } int id() const { return id_; } int size() const { return size_; } diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp new file mode 100644 index 0000000000..5005d76775 --- /dev/null +++ b/test/unit/test_network.cpp @@ -0,0 +1,817 @@ +#include + +#include +#include + +#include "network_impl.hpp" + +#include +#include + +using namespace arb; + +namespace { +std::vector test_sites = { + {0, cell_kind::cable, hash_value("a"), {1, 0.5}, {0.0, 0.0, 0.0}}, + {1, cell_kind::benchmark, hash_value("b"), {0, 0.0}, {1.0, 0.0, 0.0}}, + {2, cell_kind::lif, hash_value("c"), {0, 0.0}, {2.0, 0.0, 0.0}}, + {3, cell_kind::spike_source, hash_value("d"), {0, 0.0}, {3.0, 0.0, 0.0}}, + {4, cell_kind::cable, hash_value("e"), {0, 0.2}, {4.0, 0.0, 0.0}}, + {5, cell_kind::cable, hash_value("f"), {5, 0.1}, {5.0, 0.0, 0.0}}, + {6, cell_kind::cable, hash_value("g"), {4, 0.3}, {6.0, 0.0, 0.0}}, + {7, cell_kind::cable, hash_value("h"), {0, 1.0}, {7.0, 0.0, 0.0}}, + {9, cell_kind::cable, hash_value("i"), {0, 0.1}, {12.0, 3.0, 4.0}}, + + {10, cell_kind::cable, hash_value("a"), {0, 0.1}, {12.0, 15.0, 16.0}}, + {10, cell_kind::cable, hash_value("b"), {1, 0.1}, {13.0, 15.0, 16.0}}, + {10, cell_kind::cable, hash_value("c"), {1, 0.5}, {14.0, 15.0, 16.0}}, + {10, cell_kind::cable, hash_value("d"), {1, 1.0}, {15.0, 15.0, 16.0}}, + {10, cell_kind::cable, hash_value("e"), {2, 0.1}, {16.0, 15.0, 16.0}}, + {10, cell_kind::cable, hash_value("f"), {3, 0.1}, {16.0, 16.0, 16.0}}, + {10, cell_kind::cable, hash_value("g"), {4, 0.1}, {12.0, 17.0, 16.0}}, + {10, cell_kind::cable, hash_value("h"), {5, 0.1}, {12.0, 18.0, 16.0}}, + {10, cell_kind::cable, hash_value("i"), {6, 0.1}, {12.0, 19.0, 16.0}}, + + {11, cell_kind::cable, hash_value("abcd"), {0, 0.1}, {-2.0, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("cabd"), {1, 0.2}, {-2.1, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("cbad"), {1, 0.3}, {-2.2, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("acbd"), {1, 1.0}, {-2.3, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("bacd"), {2, 0.2}, {-2.4, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("bcad"), {3, 0.3}, {-2.5, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("dabc"), {4, 0.4}, {-2.6, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("dbca"), {5, 0.5}, {-2.7, -5.0, 3.0}}, + {11, cell_kind::cable, hash_value("dcab"), {6, 0.6}, {-2.8, -5.0, 3.0}}, +}; +} + +TEST(network_selection, all) { + const auto s = thingify(network_selection::all(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_TRUE(s->select_connection(source, target)); } + } +} + +TEST(network_selection, none) { + const auto s = thingify(network_selection::none(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_FALSE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_FALSE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_FALSE(s->select_connection(source, target)); } + } +} + +TEST(network_selection, source_cell_kind) { + const auto s = + thingify(network_selection::source_cell_kind(cell_kind::benchmark), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.kind == cell_kind::benchmark, s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.kind == cell_kind::benchmark, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, target_cell_kind) { + const auto s = + thingify(network_selection::target_cell_kind(cell_kind::benchmark), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.kind == cell_kind::benchmark, s->select_target(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(target.kind == cell_kind::benchmark, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, source_label) { + const auto s = thingify(network_selection::source_label({"b", "e"}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.label == hash_value("b") || site.label == hash_value("e"), + s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.label == hash_value("b") || source.label == hash_value("e"), + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, target_label) { + const auto s = thingify(network_selection::target_label({"b", "e"}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.label == hash_value("b") || site.label == hash_value("e"), + s->select_target(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(target.label == hash_value("b") || target.label == hash_value("e"), + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, source_cell_vec) { + const auto s = thingify(network_selection::source_cell({{1, 5}}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 1 || source.gid == 5, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, target_cell_vec) { + const auto s = thingify(network_selection::target_cell({{1, 5}}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(target.gid == 1 || target.gid == 5, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, source_cell_range) { + const auto s = + thingify(network_selection::source_cell(gid_range(1, 6, 4)), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 1 || source.gid == 5, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, target_cell_range) { + const auto s = + thingify(network_selection::target_cell(gid_range(1, 6, 4)), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(target.gid == 1 || target.gid == 5, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, chain) { + const auto s = thingify(network_selection::chain({{0, 2, 5}}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 0 || site.gid == 2, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 2 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 0 && target.gid == 2) || (source.gid == 2 && target.gid == 5), + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, chain_range) { + const auto s = thingify(network_selection::chain({gid_range(1, 8, 3)}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 4 || site.gid == 7, s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 1 && target.gid == 4) || (source.gid == 4 && target.gid == 7), + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, chain_range_reverse) { + const auto s = + thingify(network_selection::chain_reverse({gid_range(1, 8, 3)}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 7 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 4 || site.gid == 1, s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 7 && target.gid == 4) || (source.gid == 4 && target.gid == 1), + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, inter_cell) { + const auto s = thingify(network_selection::inter_cell(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid != target.gid, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, named) { + network_label_dict dict; + dict.set("mysel", network_selection::inter_cell()); + const auto s = thingify(network_selection::named("mysel"), dict); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid != target.gid, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, intersect) { + const auto s = thingify(network_selection::intersect(network_selection::source_cell({1}), + network_selection::target_cell({2})), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 1, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ(site.gid == 2, s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 1 && target.gid == 2, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, join) { + const auto s = thingify( + network_selection::join(network_selection::intersect(network_selection::source_cell({1}), + network_selection::target_cell({2})), + network_selection::intersect( + network_selection::source_cell({4}), network_selection::target_cell({5}))), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 2 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 1 && target.gid == 2) || (source.gid == 4 && target.gid == 5), + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, difference) { + const auto s = + thingify(network_selection::difference(network_selection::source_cell({{0, 1, 2}}), + network_selection::source_cell({{1, 3}})), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 0 || site.gid == 1 || site.gid == 2, + s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 0 || source.gid == 2, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, symmetric_difference) { + const auto s = thingify( + network_selection::symmetric_difference( + network_selection::source_cell({{0, 1, 2}}), network_selection::source_cell({{1, 3}})), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 0 || site.gid == 1 || site.gid == 2 || site.gid == 3, + s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 0 || source.gid == 2 || source.gid == 3, + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, complement) { + const auto s = thingify( + network_selection::complement(network_selection::inter_cell()), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == target.gid, s->select_connection(source, target)); + } + } +} + +TEST(network_selection, random_p_1) { + const auto s = thingify(network_selection::random(42, 1.0), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_TRUE(s->select_connection(source, target)); } + } +} + +TEST(network_selection, random_p_0) { + const auto s = thingify(network_selection::random(42, 0.0), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_FALSE(s->select_connection(source, target)); } + } +} + +TEST(network_selection, random_seed) { + const auto s1 = thingify(network_selection::random(42, 0.5), network_label_dict()); + const auto s2 = thingify(network_selection::random(4592304, 0.5), network_label_dict()); + + bool all_eq = true; + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + all_eq &= + (s1->select_connection(source, target) == s2->select_connection(source, target)); + } + } + EXPECT_FALSE(all_eq); +} + +TEST(network_selection, random_reproducibility) { + const auto s = thingify(network_selection::random(42, 0.5), network_label_dict()); + + std::vector sites = { + {0, cell_kind::cable, hash_value("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = {0, 1, 1, 0, 1, 1, 1, 1, 1}; + + std::size_t i = 0; + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_EQ(ref.at(i), s->select_connection(source, target)); + ++i; + } + }; +} + +TEST(network_selection, custom) { + auto inter_cell_func = [](const network_site_info& source, const network_site_info& target) { + return source.gid != target.gid; + }; + const auto s = thingify(network_selection::custom(inter_cell_func), network_label_dict()); + const auto s_ref = thingify(network_selection::inter_cell(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ( + s->select_connection(source, target), s_ref->select_connection(source, target)); + } + } +} + +TEST(network_selection, distance_lt) { + const double d = 2.1; + const auto s = thingify(network_selection::distance_lt(d), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(distance(source.global_location, target.global_location) < d, + s->select_connection(source, target)); + } + } +} + +TEST(network_selection, distance_gt) { + const double d = 2.1; + const auto s = thingify(network_selection::distance_gt(d), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(distance(source.global_location, target.global_location) > d, + s->select_connection(source, target)); + } + } +} + +TEST(network_value, scalar) { + const auto v = thingify(network_value::scalar(2.0), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(source, target)); } + } +} + +TEST(network_value, conversion) { + const auto v = thingify(static_cast(2.0), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(source, target)); } + } +} + +TEST(network_value, named) { + auto dict = network_label_dict(); + dict.set("myval", network_value::scalar(2.0)); + const auto v = thingify(network_value::named("myval"), dict); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(source, target)); } + } +} + +TEST(network_value, distance) { + const auto v = thingify(network_value::distance(), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ( + distance(source.global_location, target.global_location), v->get(source, target)); + } + } +} + +TEST(network_value, uniform_distribution) { + const auto v = + thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); + + double mean = 0.0; + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { mean += v->get(source, target); } + } + + mean /= test_sites.size() * test_sites.size(); + EXPECT_NEAR(mean, -1.0, 1e3); +} + +TEST(network_value, uniform_distribution_reproducibility) { + const auto v = + thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); + + std::vector sites = { + {0, cell_kind::cable, hash_value("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = {0.152358748168055058, + -4.499410763769494004, + 2.208818591778559437, + -4.615620548394118394, + -2.883165846887783879, + -1.227842167463327083, + -3.938243119645829182, + -0.032436439374857962, + -3.392091783670958982}; + + std::size_t i = 0; + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(source, target)); + ++i; + } + }; +} + +TEST(network_value, normal_distribution) { + const double mean = 5.0; + const double std_dev = 3.0; + const auto v = + thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); + + double sample_mean = 0.0; + double sample_dev = 0.0; + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + const auto result = v->get(source, target); + sample_mean += result; + sample_dev += (result - mean) * (result - mean); + } + } + + sample_mean /= test_sites.size() * test_sites.size(); + sample_dev = std::sqrt(sample_dev / (test_sites.size() * test_sites.size())); + + EXPECT_NEAR(sample_mean, mean, 1e-1); + EXPECT_NEAR(sample_dev, std_dev, 1.5e-1); +} + +TEST(network_value, normal_distribution_reproducibility) { + const double mean = 5.0; + const double std_dev = 3.0; + const auto v = + thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); + + std::vector sites = { + {0, cell_kind::cable, hash_value("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = {1.719220750899862038, + 3.792930460082558852, + 2.040797389626836544, + 4.690543724504090406, + 6.048018986729678304, + 3.468450499834405676, + 2.641602074572110492, + 4.045110924716160739, + 4.619102745858998382}; + + std::size_t i = 0; + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(source, target)); + ++i; + } + }; +} + +TEST(network_value, truncated_normal_distribution) { + const double mean = 5.0; + const double std_dev = 3.0; + // symmtric upper / lower bound around mean for easy check of mean + const double lower_bound = 1.0; + const double upper_bound = 9.0; + + const auto v = thingify( + network_value::truncated_normal_distribution(42, mean, std_dev, {lower_bound, upper_bound}), + network_label_dict()); + + double sample_mean = 0.0; + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + const auto result = v->get(source, target); + EXPECT_GT(result, lower_bound); + EXPECT_LE(result, upper_bound); + sample_mean += result; + } + } + + sample_mean /= test_sites.size() * test_sites.size(); + + EXPECT_NEAR(sample_mean, mean, 1e-1); +} + +TEST(network_value, truncated_normal_distribution_reproducibility) { + const double mean = 5.0; + const double std_dev = 3.0; + + const double lower_bound = 2.0; + const double upper_bound = 9.0; + + const auto v = thingify( + network_value::truncated_normal_distribution(42, mean, std_dev, {lower_bound, upper_bound}), + network_label_dict()); + + std::vector sites = { + {0, cell_kind::cable, hash_value("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = {6.933077952929343368, + 3.822103684855993055, + 3.081517892090295696, + 3.238387276739735476, + 3.739312586647523418, + 8.589787762424691664, + 7.554985027779592244, + 2.924644471896214348, + 3.085597042676768265}; + + std::size_t i = 0; + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(source, target)); + ++i; + } + }; +} + +TEST(network_value, custom) { + auto func = [](const network_site_info& source, const network_site_info& target) { + return source.global_location.x + target.global_location.x; + }; + + const auto v = thingify(network_value::custom(func), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ( + v->get(source, target), source.global_location.x + target.global_location.x); + } + } +} + +TEST(network_value, add) { + const auto v = + thingify(network_value::add(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(v->get(source, target), 5.0); } + } +} + +TEST(network_value, sub) { + const auto v = + thingify(network_value::sub(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(v->get(source, target), -1.0); } + } +} + +TEST(network_value, mul) { + const auto v = + thingify(network_value::mul(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(v->get(source, target), 6.0); } + } +} + +TEST(network_value, div) { + const auto v = + thingify(network_value::div(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), 2.0 / 3.0); + } + } +} + +TEST(network_value, exp) { + const auto v = thingify(network_value::exp(network_value::scalar(2.0)), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), std::exp(2.0)); + } + } +} + +TEST(network_value, log) { + const auto v = thingify(network_value::log(network_value::scalar(2.0)), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), std::log(2.0)); + } + } +} + +TEST(network_value, min) { + const auto v1 = + thingify(network_value::min(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + const auto v2 = + thingify(network_value::min(network_value::scalar(3.0), network_value::scalar(2.0)), + network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v1->get(source, target), 2.0); + EXPECT_DOUBLE_EQ(v2->get(source, target), 2.0); + } + } +} + +TEST(network_value, max) { + const auto v1 = + thingify(network_value::max(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + const auto v2 = + thingify(network_value::max(network_value::scalar(3.0), network_value::scalar(2.0)), + network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v1->get(source, target), 3.0); + EXPECT_DOUBLE_EQ(v2->get(source, target), 3.0); + } + } +} + +TEST(network_value, if_else) { + const auto v1 = network_value::scalar(2.0); + const auto v2 = network_value::scalar(3.0); + + const auto s = network_selection::inter_cell(); + + const auto v = thingify(network_value::if_else(s, v1, v2), network_label_dict()); + + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), source.gid != target.gid ? 2.0 : 3.0); + } + } +} diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index 17e9dc3df2..f08607ce5e 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -8,10 +8,12 @@ #include #include #include +#include -#include #include +#include #include +#include #include "parse_s_expr.hpp" #include "util/strprintf.hpp" @@ -172,7 +174,7 @@ TEST(s_expr, iterate) { template std::string round_trip_label(const char* in) { if (auto x = parse_label_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -181,7 +183,7 @@ std::string round_trip_label(const char* in) { std::string round_trip_cv(const char* in) { if (auto x = parse_cv_policy_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -190,7 +192,7 @@ std::string round_trip_cv(const char* in) { std::string round_trip_region(const char* in) { if (auto x = parse_region_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -199,7 +201,7 @@ std::string round_trip_region(const char* in) { std::string round_trip_locset(const char* in) { if (auto x = parse_locset_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -208,7 +210,25 @@ std::string round_trip_locset(const char* in) { std::string round_trip_iexpr(const char* in) { if (auto x = parse_iexpr_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); + } + else { + return x.error().what(); + } +} + +std::string round_trip_network_selection(const char* in) { + if (auto x = parse_network_selection_expression(in)) { + return util::to_string(std::any_cast(*x)); + } + else { + return x.error().what(); + } +} + +std::string round_trip_network_value(const char* in) { + if (auto x = parse_network_value_expression(in)) { + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -336,9 +356,122 @@ TEST(iexpr, round_tripping) { round_trip_label("(pi)")); } +TEST(network_selection, round_tripping) { + auto network_literals = { + "(all)", + "(none)", + "(inter-cell)", + "(network-selection \"abc\")", + "(intersect (all) (none))", + "(join (all) (none))", + "(symmetric-difference (all) (none))", + "(difference (all) (none))", + "(complement (all))", + "(source-cell-kind (cable-cell))", + "(source-cell-kind (lif-cell))", + "(source-cell-kind (benchmark-cell))", + "(source-cell-kind (spike-source-cell))", + "(target-cell-kind (cable-cell))", + "(target-cell-kind (lif-cell))", + "(target-cell-kind (benchmark-cell))", + "(target-cell-kind (spike-source-cell))", + "(source-label \"abc\")", + "(source-label \"abc\" \"def\")", + "(source-label \"abc\" \"def\" \"ghi\")", + "(target-label \"abc\")", + "(target-label \"abc\" \"def\")", + "(target-label \"abc\" \"def\" \"ghi\")", + "(source-cell 0 1 3 15)", + "(source-cell (gid-range 4 8 2))", + "(target-cell 0 1 3 15)", + "(target-cell (gid-range 4 8 2))", + "(chain 3 1 0 5 7 6)", // order should be preserved + "(chain (gid-range 2 14 3))", + "(chain-reverse (gid-range 2 14 3))", + "(random 42 (scalar 0.1))", + "(random 42 (normal-distribution 43 0.5 0.1))", + "(distance-lt 0.5)", + "(distance-gt 0.5)", + }; + for (auto l: network_literals) { + EXPECT_EQ(l, round_trip_network_selection(l)); + } + + // test order for more than two arguments + EXPECT_EQ("(join (join (join (all) (none)) (inter-cell)) (source-cell 0))", + round_trip_network_selection("(join (all) (none) (inter-cell) (source-cell 0))")); + EXPECT_EQ("(intersect (intersect (intersect (all) (none)) (inter-cell)) (source-cell 0))", + round_trip_network_selection("(intersect (all) (none) (inter-cell) (source-cell 0))")); +} + + +TEST(network_value, round_tripping) { + auto network_literals = { + "(scalar 1.3)", + "(distance 1.3)", + "(network-value \"abc\")", + "(uniform-distribution 42 0 0.8)", + "(normal-distribution 42 0.5 0.1)", + "(truncated-normal-distribution 42 0.5 0.1 0.3 0.7)", + "(log (scalar 1.3))", + "(exp (scalar 1.3))", + "(if-else (inter-cell) (scalar 5.1) (log (scalar 1.3)))", + }; + + for (auto l: network_literals) { + EXPECT_EQ(l, round_trip_network_value(l)); + } + + EXPECT_EQ("(log (scalar 1.3))", round_trip_network_value("(log 1.3)")); + EXPECT_EQ("(exp (scalar 1.3))", round_trip_network_value("(exp 1.3)")); + + EXPECT_EQ( + "(add (scalar -2.1) (scalar 3.1))", round_trip_network_value("(add -2.1 (scalar 3.1))")); + EXPECT_EQ("(add (add (add (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(add -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(sub (scalar -2.1) (scalar 3.1))", round_trip_network_value("(sub -2.1 (scalar 3.1))")); + EXPECT_EQ("(sub (sub (sub (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(sub -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(mul (scalar -2.1) (scalar 3.1))", round_trip_network_value("(mul -2.1 (scalar 3.1))")); + EXPECT_EQ("(mul (mul (mul (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(mul -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(div (scalar -2.1) (scalar 3.1))", round_trip_network_value("(div -2.1 (scalar 3.1))")); + EXPECT_EQ("(div (div (div (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(div -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(min (scalar -2.1) (scalar 3.1))", round_trip_network_value("(min -2.1 (scalar 3.1))")); + EXPECT_EQ("(min (min (min (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(min -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(max (scalar -2.1) (scalar 3.1))", round_trip_network_value("(max -2.1 (scalar 3.1))")); + EXPECT_EQ("(max (max (max (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(max -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); +} + TEST(regloc, round_tripping) { EXPECT_EQ("(cable 3 0 1)", round_trip_label("(branch 3)")); - EXPECT_EQ("(intersect (tag 1) (intersect (tag 2) (tag 3)))", round_trip_label("(intersect (tag 1) (tag 2) (tag 3))")); + EXPECT_EQ("(intersect (tag 1) (intersect (tag 2) (tag 3)))", + round_trip_label("(intersect (tag 1) (tag 2) (tag 3))")); auto region_literals = { "(cable 2 0.1 0.4)", "(region \"foo\")", diff --git a/test/unit/test_spatial_tree.cpp b/test/unit/test_spatial_tree.cpp new file mode 100644 index 0000000000..f32d9a04df --- /dev/null +++ b/test/unit/test_spatial_tree.cpp @@ -0,0 +1,155 @@ +#include + +#include + +#include "util/spatial_tree.hpp" + +#include +#include +#include +#include +#include +#include + +using namespace arb; + +namespace { + +template +struct data_point { + int id = 0; + std::array point; + + bool operator<(const data_point& p) const { + return id < p.id || (id == p.id && point < p.point); + } +}; + +template +struct bounding_box_data { + bounding_box_data(std::size_t seed, + std::size_t num_points, + std::array box_min, + std::array box_max): + box_min(box_min), + box_max(box_max) { + + std::minstd_rand rand_gen(seed); + + data.reserve(num_points); + for (std::size_t i = 0; i < num_points; ++i) { + data_point p; + p.id = i; + for (std::size_t d = 0; d < DIM; ++d) { + + std::uniform_real_distribution distri(box_min[d], box_max[d]); + p.point[d] = distri(rand_gen); + } + data.emplace_back(p); + } + } + + std::array box_min; + std::array box_max; + std::vector> data; +}; + +class st_test: + public ::testing::TestWithParam< + std::tuple> { +public: + void test_spatial_tree() { + switch (std::get<0>(GetParam())) { + case 1: test_spatial_tree_dim<1>(); break; + case 2: test_spatial_tree_dim<2>(); break; + case 3: test_spatial_tree_dim<3>(); break; + case 4: test_spatial_tree_dim<4>(); break; + case 5: test_spatial_tree_dim<5>(); break; + case 6: test_spatial_tree_dim<6>(); break; + default: ASSERT_TRUE(false); + } + } + +private: + template + void test_spatial_tree_dim() { + std::size_t max_depth = std::get<1>(GetParam()); + std::size_t leaf_size_target = std::get<2>(GetParam()); + std::size_t num_points = std::get<3>(GetParam()); + + std::vector> boxes; + std::array box_min, box_max; + std::vector> data; + box_min.fill(-10.0); + box_max.fill(0.0); + + for (std::size_t i = 0; i < DIM; ++i) { + boxes.emplace_back(1, num_points, box_min, box_max); + data.insert(data.end(), boxes.back().data.begin(), boxes.back().data.end()); + box_min[i] += 20.0; + box_max[i] += 20.0; + } + + spatial_tree, DIM> tree( + max_depth, leaf_size_target, data, [](const data_point& d) { return d.point; }); + + // check box without any points + tree.bounding_box_for_each( + box_min, box_max, [](const data_point& d) { ASSERT_TRUE(false); }); + + // check iteration over full tree + { + std::vector> tree_data; + tree.for_each([&](const data_point& d) { tree_data.emplace_back(d); }); + ASSERT_EQ(data.size(), tree_data.size()); + + std::sort(data.begin(), data.end()); + std::sort(tree_data.begin(), tree_data.end()); + for (std::size_t i = 0; i < data.size(); ++i) { + ASSERT_EQ(data[i].id, tree_data[i].id); + ASSERT_EQ(data[i].point, tree_data[i].point); + } + } + + // check contents within each box + for (auto& box: boxes) { + std::vector> tree_data; + tree.bounding_box_for_each(box.box_min, box.box_max, [&](const data_point& d) { + tree_data.emplace_back(d); + }); + ASSERT_EQ(box.data.size(), tree_data.size()); + + std::sort(tree_data.begin(), tree_data.end()); + std::sort(box.data.begin(), box.data.end()); + + for (std::size_t i = 0; i < box.data.size(); ++i) { + ASSERT_EQ(box.data[i].id, tree_data[i].id); + ASSERT_EQ(box.data[i].point, tree_data[i].point); + } + } + } +}; + +std::string param_type_names( + const ::testing::TestParamInfo>& + info) { + std::stringstream stream; + + stream << "dim_" << std::get<0>(info.param); + stream << "_depth_" << std::get<1>(info.param); + stream << "_leaf_" << std::get<2>(info.param); + stream << "_n_" << std::get<3>(info.param); + + return stream.str(); +} +} // namespace + +TEST_P(st_test, param) { test_spatial_tree(); } + +INSTANTIATE_TEST_SUITE_P(spatial_tree, + st_test, + ::testing::Combine(::testing::Values(1, 2, 3), + ::testing::Values(1, 10, 20), + ::testing::Values(1, 100, 1000), + ::testing::Values(0, 1, 10, 100, 1000, 2000)), + param_type_names);