Commit: use hashset for grads

JohannesGaessler committed Nov 14, 2024
1 parent e59f2d1 · commit bbf203c

Showing 3 changed files with 60 additions and 62 deletions.
src/ggml-impl.h (4 changes: 2 additions & 2 deletions)

@@ -182,7 +182,7 @@ void ggml_hash_set_reset(struct ggml_hash_set * hash_set);
 static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);
 
 // returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted
-static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key);
 
 // returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
 static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key);
@@ -196,7 +196,7 @@ static inline size_t ggml_hash(const struct ggml_tensor * p) {
     return (size_t)(uintptr_t)p >> 4;
 }
 
-static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) {
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) {
     size_t h = ggml_hash(key) % hash_set->size;
 
     // linear probing
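For context: ggml_hash_find on its own does not establish membership. It returns the slot where the key lives, or the free slot where it would be inserted, so callers must also consult the hash set's used bitset, as the ggml.c changes below do. A minimal sketch of that membership check, assuming the declarations above (tensor_in_graph is a hypothetical helper, not a ggml API):

    static bool tensor_in_graph(const struct ggml_cgraph * cgraph, const struct ggml_tensor * t) {
        const struct ggml_hash_set * hs = &cgraph->visited_hash_set;
        const size_t i = ggml_hash_find(hs, t); // slot holding t, or the free slot it would occupy
        // a valid slot only counts as a hit if it is actually occupied
        return i != GGML_HASHSET_FULL && ggml_bitset_get(hs->used, i);
    }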
src/ggml-opt.cpp (10 changes: 6 additions & 4 deletions)

@@ -253,9 +253,11 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
     for (int i = 0; i < graph->n_nodes; i++) {
         ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
     }
-    for (int i = 0; i < ggml_graph_n_nodes(graph); i++) {
-        new_graph->grads[i] = map_tensor(tensor_map, ctx, graph->grads[i]);
-        new_graph->grad_accs[i] = map_tensor(tensor_map, ctx, graph->grad_accs[i]);
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]);
+        const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
+        new_graph->grads[igrad_dst] = graph->grads[igrad_src];
+        new_graph->grad_accs[igrad_dst] = graph->grad_accs[igrad_src];
     }
 
     return new_graph;
@@ -454,7 +456,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
 
     for (int i = result->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = result->gb_opt->nodes[i];
-        struct ggml_tensor * grad = result->gb_opt->grads[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
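Since grads and grad_accs are no longer parallel to the nodes array, callers outside ggml.c should fetch gradients per tensor through ggml_graph_get_grad rather than by node index, as the ggml_opt_init change above does. A usage sketch under that assumption (apply_sgd_step and its arguments are hypothetical, not part of ggml):

    // gb is assumed to be a graph prepared by ggml_build_backward_expand;
    // w is an assumed parameter tensor (contiguous F32).
    static void apply_sgd_step(struct ggml_cgraph * gb, struct ggml_tensor * w, float lr) {
        struct ggml_tensor * grad = ggml_graph_get_grad(gb, w);
        if (grad == NULL) {
            return; // w is not in gb, or no gradient was requested for it
        }
        for (int64_t i = 0; i < ggml_nelements(w); ++i) {
            ((float *) w->data)[i] -= lr * ((float *) grad->data)[i];
        }
    }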
src/ggml.c (108 changes: 52 additions & 56 deletions)

@@ -5031,7 +5031,7 @@ static void ggml_hash_map_free(struct hash_map * map) {
 static void ggml_add_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph * cgraph,
-        int isrc,
+        size_t isrc,
         struct ggml_tensor * tensor) {
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
@@ -5044,7 +5044,8 @@ static void ggml_add_or_set(
 static void ggml_acc_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph * cgraph,
-        int isrc,
+        size_t isrc,
+        struct ggml_tensor * src,
         struct ggml_tensor * tensor,
         const size_t nb1,
         const size_t nb2,
@@ -5053,7 +5054,7 @@ static void ggml_acc_or_set(
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
     } else {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, cgraph->nodes[isrc], 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+        struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
     }
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
@@ -5062,20 +5063,21 @@ static void ggml_acc_or_set(
 static void ggml_add1_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph * cgraph,
-        int isrc,
+        size_t isrc,
+        struct ggml_tensor * src,
         struct ggml_tensor * tensor) {
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
-        cgraph->grads[isrc] = ggml_repeat(ctx, tensor, cgraph->nodes[isrc]);
+        cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
     }
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
 static void ggml_sub_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph * cgraph,
-        int isrc,
+        size_t isrc,
         struct ggml_tensor * tensor) {
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
@@ -5085,19 +5087,10 @@ static void ggml_sub_or_set(
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
-static int ggml_find_node(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
-    for (int i = 0; i < cgraph->n_nodes; ++i) {
-        if (cgraph->nodes[i] == node) {
-            return i;
-        }
-    }
-    return -1;
-}
-
 static void ggml_compute_backward(
         struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
     struct ggml_tensor * tensor = cgraph->nodes[i];
-    struct ggml_tensor * grad = cgraph->grads[i];
+    struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
 
     if (!grad) {
         return;
@@ -5106,12 +5099,13 @@ static void ggml_compute_backward(
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
-    const int isrc0 = ggml_find_node(cgraph, src0);
-    const int isrc1 = ggml_find_node(cgraph, src1);
-    const int isrc2 = ggml_find_node(cgraph, src2);
-    const bool src0_needs_grads = isrc0 >= 0 && grads_needed[isrc0];
-    const bool src1_needs_grads = isrc1 >= 0 && grads_needed[isrc1];
-    const bool src2_needs_grads = isrc2 >= 0 && grads_needed[isrc2];
+    struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
+    const size_t isrc0 = ggml_hash_find(hash_set, src0);
+    const size_t isrc1 = ggml_hash_find(hash_set, src1);
+    const size_t isrc2 = ggml_hash_find(hash_set, src2);
+    const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+    const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+    const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
 
     switch (tensor->op) {
         case GGML_OP_DUP: {
@@ -5133,7 +5127,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_ADD1: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, cgraph->grads[i]);
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
@@ -5211,7 +5205,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SUM: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
             }
         } break;
         case GGML_OP_SUM_ROWS: {
@@ -5221,7 +5215,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
             }
         } break;
         case GGML_OP_REPEAT: {
@@ -5374,7 +5368,7 @@ static void ggml_compute_backward(
                     nb3 = (nb3 / n0) * ng;
                 }
 
-                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
+                ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
             }
         } break;
         case GGML_OP_PERMUTE: {
@@ -5531,9 +5525,9 @@ static void ggml_compute_backward(
         } break;
     }
 
-    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(cgraph->nodes[isrc0], cgraph->grads[isrc0]));
-    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(cgraph->nodes[isrc1], cgraph->grads[isrc1]));
-    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(cgraph->nodes[isrc2], cgraph->grads[isrc2]));
+    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
+    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
+    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
@@ -5608,9 +5602,10 @@ void ggml_build_backward_expand(
 
     const int n_nodes_f = cgraph->n_nodes;
 
-    memset(cgraph->grads, 0, cgraph->size*sizeof(struct ggml_tensor *));
-    memset(cgraph->grad_accs, 0, cgraph->size*sizeof(struct ggml_tensor *));
-    bool * grads_needed = calloc(n_nodes_f, sizeof(bool));
+    const size_t hash_size = ggml_hash_size(2*cgraph->size);
+    memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
+    memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
+    bool * grads_needed = calloc(hash_size, sizeof(bool));
 
     {
         bool any_params = false;
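The allocation change above encodes the new invariant: grads, grad_accs, and the temporary grads_needed are all parallel to the visited hash set's keys array (hash_size entries), so a tensor's gradient state lives at the same slot as the tensor itself. An illustrative consistency check, not part of the commit (it assumes only the internal fields visible in this diff):

    // Illustrative only: every set grads[i] belongs to the tensor at keys[i].
    static void check_grad_slots(const struct ggml_cgraph * cgraph) {
        const struct ggml_hash_set * hs = &cgraph->visited_hash_set;
        for (size_t i = 0; i < hs->size; ++i) {
            if (cgraph->grads && cgraph->grads[i]) {
                GGML_ASSERT(ggml_bitset_get(hs->used, i)); // a grad implies an occupied slot
                GGML_ASSERT(ggml_are_same_shape(hs->keys[i], cgraph->grads[i]));
            }
        }
    }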
@@ -5659,7 +5654,7 @@ void ggml_build_backward_expand(
                 break;
         }
         for (int j = 0; j < GGML_MAX_SRC; ++j) {
-            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_find_node(cgraph, node->src[j])]) {
+            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
                 continue;
             }
             GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
@@ -5674,11 +5669,12 @@ void ggml_build_backward_expand(
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
+        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
         if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grads[i] = ggml_dup_tensor(ctx_static, node);
-            cgraph->grad_accs[i] = cgraph->grads[i];
+            cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node);
+            cgraph->grad_accs[igrad] = cgraph->grads[igrad];
         }
-        grads_needed[i] = true;
+        grads_needed[igrad] = true;
     }
 
     for (int i = n_nodes_f - 1; i >= 0; --i) {
@@ -5705,8 +5701,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
     incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
     if (grads) {
-        incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
-        incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
+        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
+        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
     }
     incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
 
@@ -5735,8 +5731,8 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
     struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
     struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
     struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
-    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
 
     ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
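Sizing the two gradient arrays by hash_size instead of size costs some extra memory but makes hash-slot indexing safe everywhere. A worked example (the numbers are illustrative; ggml_hash_size rounds up via a prime table, so the exact value is an assumption here):

    // For ggml_new_graph_custom(ctx, /*size =*/ 2048, /*grads =*/ true),
    // with hash_size = ggml_hash_size(2*2048), assumed to resolve to 4099:
    //   nodes:     2048 * sizeof(struct ggml_tensor *)
    //   leafs:     2048 * sizeof(struct ggml_tensor *)
    //   hash keys: 4099 * sizeof(struct ggml_tensor *)
    //   grads:     4099 * sizeof(struct ggml_tensor *)  // was 2048 before this commit
    //   grad_accs: 4099 * sizeof(struct ggml_tensor *)  // was 2048 before this commit
    //   used:      ggml_bitset_size(4099) * sizeof(ggml_bitset_t)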

@@ -5797,23 +5793,23 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
         dst->nodes[i] = src->nodes[i];
     }
 
-    if (src->grads) {
-        GGML_ASSERT(dst->grads != NULL);
-        for (int i = 0; i < src->n_nodes; ++i) {
-            dst->grads[i] = src->grads[i];
-        }
-        GGML_ASSERT(dst->grad_accs != NULL);
-        for (int i = 0; i < src->n_nodes; ++i) {
-            dst->grad_accs[i] = src->grad_accs[i];
-        }
-    }
-
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
         if (ggml_bitset_get(src->visited_hash_set.used, i)) {
             ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
         }
     }
+
+    if (src->grads) {
+        GGML_ASSERT(dst->grads != NULL);
+        GGML_ASSERT(dst->grad_accs != NULL);
+        for (int i = 0; i < src->n_nodes; ++i) {
+            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+            dst->grads[igrad_dst] = src->grads[igrad_src];
+            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
+        }
+    }
 }
 
 struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
@@ -5840,7 +5836,7 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
-        struct ggml_tensor * grad_acc = cgraph->grad_accs[i];
+        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);
 
         if (node->op == GGML_OP_OPT_STEP_ADAMW) {
             // clear momenta
@@ -5927,13 +5923,13 @@ struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, co
 }
 
 struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
-    const int i = ggml_find_node(cgraph, node);
-    return i >= 0 ? cgraph->grads[i] : NULL;
+    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
 }
 
 struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
-    const int i = ggml_find_node(cgraph, node);
-    return i >= 0 ? cgraph->grad_accs[i] : NULL;
+    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
 }
 
 void ggml_graph_print(const struct ggml_cgraph * cgraph) {
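Taken together, the accessors touched by this commit are exercised roughly as follows. This is a hypothetical end-to-end sketch against the API as of this commit (context size and tensor shape are assumptions, error handling omitted), not code from the repository:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 128*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        ggml_set_param(ctx, x);                          // request a gradient for x
        struct ggml_tensor * loss = ggml_sum(ctx, ggml_sqr(ctx, x));
        ggml_set_loss(loss);

        struct ggml_cgraph * gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
        ggml_build_forward_expand(gb, loss);
        ggml_build_backward_expand(ctx, ctx, gb, /*accumulate =*/ false);

        // gradient lookup now goes through the visited hash set internally
        struct ggml_tensor * x_grad = ggml_graph_get_grad(gb, x);
        (void) x_grad;

        ggml_graph_reset(gb); // re-zeroes gradients via the same per-tensor lookup
        ggml_free(ctx);
        return 0;
    }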
