Skip to content

Commit

Permalink
Reduce even more GPU allocations (#2394)
Browse files Browse the repository at this point in the history
Bubbling up ion access information from modcc allows us to skip
allocation of `Xi` and `Xo` iff
no mechanism reads those values. 

A minor problem: sampling may try to touch data that doesn't exist;
however, who'd ask for `Xi`
if they never use it?

Stacking this PR on top of #2393, the memory consumption drops by a
further 5%, down to 77%
of the original value.

### Todo
- [x] Guard against sampling non-existing `Xi`/`Xo`

---------

Co-authored-by: Jannik Luboeinski <33398515+jlubo@users.noreply.github.com>
  • Loading branch information
thorstenhater and jlubo authored Oct 21, 2024
1 parent f302088 commit 9de9174
Show file tree
Hide file tree
Showing 21 changed files with 270 additions and 166 deletions.
47 changes: 17 additions & 30 deletions arbor/backends/gpu/matrix_state_fine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <cstring>

#include <vector>
#include <type_traits>

#include <arbor/common_types.hpp>

Expand Down Expand Up @@ -37,15 +36,11 @@ struct matrix_state_fine {
array rhs; // [nA]

// Required for matrix assembly
array cv_area; // [μm^2]
array cv_capacitance; // [pF]

// Invariant part of the matrix diagonal
array invariant_d; // [μS]

// Solution in unpacked format
array solution_;

// Maximum number of branches in each level per block
unsigned max_branches_per_level;

Expand Down Expand Up @@ -82,16 +77,13 @@ struct matrix_state_fine {
// `solver_format[perm[i]] = external_format[i]`
iarray perm;


matrix_state_fine() = default;

// constructor for fine-grained matrix.
matrix_state_fine(const std::vector<size_type>& p,
const std::vector<size_type>& cell_cv_divs,
const std::vector<value_type>& cap,
const std::vector<value_type>& face_conductance,
const std::vector<value_type>& area)
{
const std::vector<size_type>& cell_cv_divs,
const std::vector<value_type>& cap,
const std::vector<value_type>& face_conductance) {
using util::make_span;
constexpr unsigned npos = unsigned(-1);

Expand Down Expand Up @@ -360,7 +352,6 @@ struct matrix_state_fine {
// cv_capacitance : flat
// invariant_d : flat
// cv_to_cell : flat
// area : flat

// the invariant part of d is stored in in flat form
std::vector<value_type> invariant_d_tmp(matrix_size, 0);
Expand All @@ -386,9 +377,6 @@ struct matrix_state_fine {
// transform u_shuffled values into packed u vector.
flat_to_packed(u_shuffled, u);

// the invariant part of d and cv_area are in flat form
cv_area = memory::make_const_view(area);

// the cv_capacitance can be copied directly because it is
// to be stored in flat format
cv_capacitance = memory::make_const_view(cap);
Expand All @@ -408,19 +396,18 @@ struct matrix_state_fine {
// voltage [mV]
// current density [A/m²]
// conductivity [kS/m²]
void assemble(const T dt, const_view voltage, const_view current, const_view conductivity) {
assemble_matrix_fine(
d.data(),
rhs.data(),
invariant_d.data(),
voltage.data(),
current.data(),
conductivity.data(),
cv_capacitance.data(),
cv_area.data(),
dt,
perm.data(),
size());
void assemble(const T dt, const_view voltage, const_view current, const_view conductivity, const_view area_um2) {
assemble_matrix_fine(d.data(),
rhs.data(),
invariant_d.data(),
voltage.data(),
current.data(),
conductivity.data(),
cv_capacitance.data(),
area_um2.data(),
dt,
perm.data(),
size());
}

void solve(array& to) {
Expand All @@ -441,8 +428,8 @@ struct matrix_state_fine {


void solve(array& voltage,
const T dt, const_view current, const_view conductivity) {
assemble(dt, voltage, current, conductivity);
const T dt, const_view current, const_view conductivity, const_view area_um2) {
assemble(dt, voltage, current, conductivity, area_um2);
solve(voltage);
}

Expand Down
46 changes: 34 additions & 12 deletions arbor/backends/gpu/shared_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,42 @@ ion_state::ion_state(const fvm_ion_config& ion_data,
write_eX_(ion_data.revpot_written),
write_Xo_(ion_data.econc_written),
write_Xi_(ion_data.iconc_written),
write_Xd_(ion_data.is_diffusive),
read_Xo_(ion_data.econc_written || ion_data.econc_read), // ensure that if we have W access, also R access is flagged
read_Xi_(ion_data.iconc_written || ion_data.iconc_read),
node_index_(make_const_view(ion_data.cv)),
iX_(ion_data.cv.size(), NAN),
eX_(ion_data.init_revpot.begin(), ion_data.init_revpot.end()),
Xi_(ion_data.init_iconc.begin(), ion_data.init_iconc.end()),
Xd_(ion_data.cv.size(), NAN),
Xo_(ion_data.init_econc.begin(), ion_data.init_econc.end()),
gX_(ion_data.cv.size(), NAN),
init_Xi_(make_const_view(ion_data.init_iconc)),
init_Xo_(make_const_view(ion_data.init_econc)),
reset_Xi_(make_const_view(ion_data.reset_iconc)),
reset_Xo_(make_const_view(ion_data.reset_econc)),
init_eX_(make_const_view(ion_data.init_revpot)),
charge(1u, static_cast<arb_value_type>(ion_data.charge)),
solver(std::move(ptr)) {
arb_assert(node_index_.size()==init_Xi_.size());
arb_assert(node_index_.size()==init_Xo_.size());
arb_assert(node_index_.size()==init_eX_.size());
// We don't need to allocate these if we never use them...
if (read_Xi_) {
Xi_ = make_const_view(ion_data.init_iconc);
}
if (read_Xo_) {
Xo_ = make_const_view(ion_data.init_econc);
}
if (write_Xi_ || write_Xd_) {
// ... but this is used by Xd and Xi!
reset_Xi_ = make_const_view(ion_data.reset_iconc);
}
if (write_Xi_) {
init_Xi_ = make_const_view(ion_data.init_iconc);
arb_assert(node_index_.size()==init_Xi_.size());
}
if (write_Xo_) {
init_Xo_ = make_const_view(ion_data.init_econc);
reset_Xo_ = make_const_view(ion_data.reset_econc);
arb_assert(node_index_.size()==init_Xo_.size());
}
if (write_eX_) {
init_eX_ = make_const_view(ion_data.init_revpot);
arb_assert(node_index_.size()==init_eX_.size());
}
if (write_Xd_) {
Xd_ = array(ion_data.cv.size(), NAN);
}
}

void ion_state::init_concentration() {
Expand All @@ -81,10 +100,13 @@ void ion_state::zero_current() {

void ion_state::reset() {
zero_current();
memory::copy(reset_Xi_, Xd_);
if (write_Xi_) memory::copy(reset_Xi_, Xi_);
if (write_Xo_) memory::copy(reset_Xo_, Xo_);
if (write_eX_) memory::copy(init_eX_, eX_);
// This goes _last_, or at least after Xi, since we might have removed reset_Xi
// when Xi is constant. Thus conditionally resetting Xi first and then copying
// reset_Xi -> Xd is safe in all cases.
if (write_Xd_) memory::copy(reset_Xi_, Xd_);
}

// istim_state methods:
Expand Down
11 changes: 8 additions & 3 deletions arbor/backends/gpu/shared_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ struct ARB_ARBOR_API ion_state {
using solver_type = arb::gpu::diffusion_state<arb_value_type, arb_index_type>;
using solver_ptr = std::unique_ptr<solver_type>;

bool write_eX_; // is eX written?
bool write_Xo_; // is Xo written?
bool write_Xi_; // is Xi written?
bool write_eX_:1; // is eX written?
bool write_Xo_:1; // is Xo written?
bool write_Xi_:1; // is Xi written?
bool write_Xd_:1; // is Xd written?
bool read_eX_:1; // is eX read?
bool read_Xo_:1; // is Xo read?
bool read_Xi_:1; // is Xi read?
bool read_Xd_:1; // is Xd read?

iarray node_index_; // Instance to CV map.
array iX_; // (A/m²) current density
Expand Down
7 changes: 2 additions & 5 deletions arbor/backends/multicore/cable_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ struct cable_solver {
array d; // [μS]
array u; // [μS]
array cv_capacitance; // [pF]
array cv_area; // [μm^2]
array invariant_d; // [μS] invariant part of matrix diagonal

cable_solver() = default;
Expand All @@ -36,13 +35,11 @@ struct cable_solver {
cable_solver(const std::vector<index_type>& p,
const std::vector<index_type>& cell_cv_divs,
const std::vector<value_type>& cap,
const std::vector<value_type>& cond,
const std::vector<value_type>& area):
const std::vector<value_type>& cond):
parent_index(p.begin(), p.end()),
cell_cv_divs(cell_cv_divs.begin(), cell_cv_divs.end()),
d(size(), 0), u(size(), 0),
cv_capacitance(cap.begin(), cap.end()),
cv_area(area.begin(), area.end()),
invariant_d(size(), 0)
{
// Sanity check
Expand All @@ -67,7 +64,7 @@ struct cable_solver {
// * expects the voltage from its first argument
// * will likewise overwrite the first argument with the solution
template<typename T>
void solve(T& rhs, const value_type dt, const_view current, const_view conductivity) {
void solve(T& rhs, const value_type dt, const_view current, const_view conductivity, const_view cv_area) {
value_type * const ARB_NO_ALIAS d_ = d.data();
value_type * const ARB_NO_ALIAS r_ = rhs.data();

Expand Down
49 changes: 35 additions & 14 deletions arbor/backends/multicore/shared_state.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -58,24 +57,43 @@ ion_state::ion_state(const fvm_ion_config& ion_data,
write_eX_(ion_data.revpot_written),
write_Xo_(ion_data.econc_written),
write_Xi_(ion_data.iconc_written),
write_Xd_(ion_data.is_diffusive),
read_Xo_(ion_data.econc_written || ion_data.econc_read), // ensure that if we have W access, also R access is flagged
read_Xi_(ion_data.iconc_written || ion_data.iconc_read),
read_Xd_(ion_data.is_diffusive),
node_index_(ion_data.cv.begin(), ion_data.cv.end(), pad(alignment)),
iX_(ion_data.cv.size(), NAN, pad(alignment)),
eX_(ion_data.init_revpot.begin(), ion_data.init_revpot.end(), pad(alignment)),
Xi_(ion_data.init_iconc.begin(), ion_data.init_iconc.end(), pad(alignment)),
Xd_(ion_data.reset_iconc.begin(), ion_data.reset_iconc.end(), pad(alignment)),
Xo_(ion_data.init_econc.begin(), ion_data.init_econc.end(), pad(alignment)),
gX_(ion_data.cv.size(), NAN, pad(alignment)),
init_Xi_(ion_data.init_iconc.begin(), ion_data.init_iconc.end(), pad(alignment)),
init_Xo_(ion_data.init_econc.begin(), ion_data.init_econc.end(), pad(alignment)),
reset_Xi_(ion_data.reset_iconc.begin(), ion_data.reset_iconc.end(), pad(alignment)),
reset_Xo_(ion_data.reset_econc.begin(), ion_data.reset_econc.end(), pad(alignment)),
init_eX_(ion_data.init_revpot.begin(), ion_data.init_revpot.end(), pad(alignment)),
charge(1u, ion_data.charge, pad(alignment)),
solver(std::move(ptr)) {
arb_assert(node_index_.size()==init_Xi_.size());
arb_assert(node_index_.size()==init_Xo_.size());
arb_assert(node_index_.size()==eX_.size());
arb_assert(node_index_.size()==init_eX_.size());
// We don't need to allocate these if we never use them...
if (read_Xi_) {
Xi_ = {ion_data.init_iconc.begin(), ion_data.init_iconc.end(), pad(alignment)};
}
if (read_Xo_) {
Xo_ = {ion_data.init_econc.begin(), ion_data.init_econc.end(), pad(alignment)};
}
if (write_Xi_ || write_Xd_) {
// ... but this is used by Xd and Xi!
reset_Xi_ = {ion_data.reset_iconc.begin(), ion_data.reset_iconc.end(), pad(alignment)};
}
if (write_Xi_) {
init_Xi_ = {ion_data.init_iconc.begin(), ion_data.init_iconc.end(), pad(alignment)};
arb_assert(node_index_.size()==init_Xi_.size());
}
if (write_Xo_) {
init_Xo_ = {ion_data.init_econc.begin(), ion_data.init_econc.end(), pad(alignment)};
reset_Xo_ = {ion_data.reset_econc.begin(), ion_data.reset_econc.end(), pad(alignment)};
arb_assert(node_index_.size()==init_Xo_.size());
}
if (write_eX_) {
init_eX_ = {ion_data.init_revpot.begin(), ion_data.init_revpot.end(), pad(alignment)};
arb_assert(node_index_.size()==init_eX_.size());
}
if (read_Xd_) {
Xd_ = {ion_data.reset_iconc.begin(), ion_data.reset_iconc.end(), pad(alignment)};
}
}

void ion_state::init_concentration() {
Expand All @@ -91,10 +109,13 @@ void ion_state::zero_current() {

void ion_state::reset() {
zero_current();
std::copy(reset_Xi_.begin(), reset_Xi_.end(), Xd_.begin());
if (write_Xi_) std::copy(reset_Xi_.begin(), reset_Xi_.end(), Xi_.begin());
if (write_Xo_) std::copy(reset_Xo_.begin(), reset_Xo_.end(), Xo_.begin());
if (write_eX_) std::copy(init_eX_.begin(), init_eX_.end(), eX_.begin());
// This goes _last_ or at least after Xi since we might have removed reset_Xi
// when Xi is constant. Thus conditionally resetting Xi first and then copying
// Xi -> Xd is safe in all cases.
if (write_Xd_) std::copy(Xi_.begin(), Xi_.end(), Xd_.begin());
}

// istim_state methods:
Expand Down
11 changes: 8 additions & 3 deletions arbor/backends/multicore/shared_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ struct ARB_ARBOR_API ion_state {

unsigned alignment = 1; // Alignment and padding multiple.

bool write_eX_; // is eX written?
bool write_Xo_; // is Xo written?
bool write_Xi_; // is Xi written?
bool write_eX_:1; // is eX written?
bool write_Xo_:1; // is Xo written?
bool write_Xi_:1; // is Xi written?
bool write_Xd_:1; // is Xd written?
bool read_eX_:1; // is eX read?
bool read_Xo_:1; // is Xo read?
bool read_Xi_:1; // is Xi read?
bool read_Xd_:1; // is Xd read?

iarray node_index_; // Instance to CV map.
array iX_; // (A/m²) current density
Expand Down
7 changes: 4 additions & 3 deletions arbor/backends/shared_state_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
#include "backends/common_types.hpp"
#include "fvm_layout.hpp"

#include "event_lane.hpp"
#include "util/rangeutil.hpp"
#include "timestep_range.hpp"
#include "event_lane.hpp"

namespace arb {

Expand Down Expand Up @@ -48,7 +49,7 @@ struct shared_state_base {

void configure_solver(const fvm_cv_discretization& disc) {
auto d = static_cast<D*>(this);
d->solver = {disc.geometry.cv_parent, disc.geometry.cell_cv_divs, disc.cv_capacitance, disc.face_conductance, disc.cv_area};
d->solver = {disc.geometry.cv_parent, disc.geometry.cell_cv_divs, disc.cv_capacitance, disc.face_conductance};
}

void add_ion(const std::string& ion_name,
Expand Down Expand Up @@ -134,7 +135,7 @@ struct shared_state_base {

void integrate_cable_state() {
auto d = static_cast<D*>(this);
d->solver.solve(d->voltage, d->dt, d->current_density, d->conductivity);
d->solver.solve(d->voltage, d->dt, d->current_density, d->conductivity, d->area_um2);
for (auto& [ion, data]: d->ion_data) {
if (data.solver) {
data.solver->solve(data.Xd_,
Expand Down
Loading

0 comments on commit 9de9174

Please sign in to comment.