diff --git a/src/ggml-impl.h b/src/ggml-impl.h index b53be5c4dd..f1f0783731 100644 --- a/src/ggml-impl.h +++ b/src/ggml-impl.h @@ -182,7 +182,7 @@ void ggml_hash_set_reset(struct ggml_hash_set * hash_set); static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); // returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key); // returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); @@ -196,7 +196,7 @@ static inline size_t ggml_hash(const struct ggml_tensor * p) { return (size_t)(uintptr_t)p >> 4; } -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) { size_t h = ggml_hash(key) % hash_set->size; // linear probing diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp index 9279a5ff8e..c1326bc430 100644 --- a/src/ggml-opt.cpp +++ b/src/ggml-opt.cpp @@ -253,9 +253,11 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) { for (int i = 0; i < graph->n_nodes; i++) { ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i])); } - for (int i = 0; i < ggml_graph_n_nodes(graph); i++) { - new_graph->grads[i] = map_tensor(tensor_map, ctx, graph->grads[i]); - new_graph->grad_accs[i] = map_tensor(tensor_map, ctx, graph->grad_accs[i]); + for (int i = 0; i < graph->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]); + 
new_graph->grads[igrad_dst] = map_tensor(tensor_map, ctx, graph->grads[igrad_src]); + new_graph->grad_accs[igrad_dst] = map_tensor(tensor_map, ctx, graph->grad_accs[igrad_src]); } return new_graph; @@ -454,7 +456,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { for (int i = result->gf->n_nodes-1; i >= 0; --i) { struct ggml_tensor * node = result->gb_opt->nodes[i]; - struct ggml_tensor * grad = result->gb_opt->grads[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); if (node->flags & GGML_TENSOR_FLAG_PARAM) { struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); diff --git a/src/ggml.c b/src/ggml.c index aba5678388..0c28345394 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -5031,7 +5031,7 @@ static void ggml_hash_map_free(struct hash_map * map) { static void ggml_add_or_set( struct ggml_context * ctx, struct ggml_cgraph * cgraph, - int isrc, + size_t isrc, struct ggml_tensor * tensor) { if (cgraph->grads[isrc]) { cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); @@ -5044,7 +5044,8 @@ static void ggml_add_or_set( static void ggml_acc_or_set( struct ggml_context * ctx, struct ggml_cgraph * cgraph, - int isrc, + size_t isrc, + struct ggml_tensor * src, struct ggml_tensor * tensor, const size_t nb1, const size_t nb2, @@ -5053,7 +5054,7 @@ static void ggml_acc_or_set( if (cgraph->grads[isrc]) { cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); } else { - struct ggml_tensor * a_zero = ggml_scale(ctx, cgraph->nodes[isrc], 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); } ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); @@ -5062,12 +5063,13 @@ static void ggml_acc_or_set( static void ggml_add1_or_set( struct ggml_context * ctx, 
struct ggml_cgraph * cgraph, - int isrc, + size_t isrc, + struct ggml_tensor * src, struct ggml_tensor * tensor) { if (cgraph->grads[isrc]) { cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); } else { - cgraph->grads[isrc] = ggml_repeat(ctx, tensor, cgraph->nodes[isrc]); + cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src); } ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } @@ -5075,7 +5077,7 @@ static void ggml_add1_or_set( static void ggml_sub_or_set( struct ggml_context * ctx, struct ggml_cgraph * cgraph, - int isrc, + size_t isrc, struct ggml_tensor * tensor) { if (cgraph->grads[isrc]) { cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); @@ -5085,19 +5087,10 @@ static void ggml_sub_or_set( ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static int ggml_find_node(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - for (int i = 0; i < cgraph->n_nodes; ++i) { - if (cgraph->nodes[i] == node) { - return i; - } - } - return -1; -} - static void ggml_compute_backward( struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { struct ggml_tensor * tensor = cgraph->nodes[i]; - struct ggml_tensor * grad = cgraph->grads[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); if (!grad) { return; @@ -5106,12 +5099,13 @@ static void ggml_compute_backward( struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; struct ggml_tensor * src2 = tensor->src[2]; - const int isrc0 = ggml_find_node(cgraph, src0); - const int isrc1 = ggml_find_node(cgraph, src1); - const int isrc2 = ggml_find_node(cgraph, src2); - const bool src0_needs_grads = isrc0 >= 0 && grads_needed[isrc0]; - const bool src1_needs_grads = isrc1 >= 0 && grads_needed[isrc1]; - const bool src2_needs_grads = isrc2 >= 0 && grads_needed[isrc2]; + struct ggml_hash_set * hash_set = &cgraph->visited_hash_set; + const 
size_t isrc0 = ggml_hash_find(hash_set, src0); + const size_t isrc1 = ggml_hash_find(hash_set, src1); + const size_t isrc2 = ggml_hash_find(hash_set, src2); + const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; switch (tensor->op) { case GGML_OP_DUP: { @@ -5133,7 +5127,7 @@ static void ggml_compute_backward( } break; case GGML_OP_ADD1: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, cgraph->grads[i]); + ggml_add_or_set(ctx, cgraph, isrc0, grad); } if (src1_needs_grads) { ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean @@ -5211,7 +5205,7 @@ static void ggml_compute_backward( } break; case GGML_OP_SUM: { if (src0_needs_grads) { - ggml_add1_or_set(ctx, cgraph, isrc0, grad); + ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad); } } break; case GGML_OP_SUM_ROWS: { @@ -5221,7 +5215,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MEAN: { if (src0_needs_grads) { - ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); } } break; case GGML_OP_REPEAT: { @@ -5374,7 +5368,7 @@ static void ggml_compute_backward( nb3 = (nb3 / n0) * ng; } - ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset); + ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset); } } break; case GGML_OP_PERMUTE: { @@ -5531,9 +5525,9 @@ static void ggml_compute_backward( } break; } - GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(cgraph->nodes[isrc0], cgraph->grads[isrc0])); - GGML_ASSERT(!src1_needs_grads || 
ggml_are_same_shape(cgraph->nodes[isrc1], cgraph->grads[isrc1])); - GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(cgraph->nodes[isrc2], cgraph->grads[isrc2])); + GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0])); + GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1])); + GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { @@ -5608,9 +5602,10 @@ void ggml_build_backward_expand( const int n_nodes_f = cgraph->n_nodes; - memset(cgraph->grads, 0, cgraph->size*sizeof(struct ggml_tensor *)); - memset(cgraph->grad_accs, 0, cgraph->size*sizeof(struct ggml_tensor *)); - bool * grads_needed = calloc(n_nodes_f, sizeof(bool)); + const size_t hash_size = ggml_hash_size(2*cgraph->size); + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + bool * grads_needed = calloc(hash_size, sizeof(bool)); { bool any_params = false; @@ -5659,7 +5654,7 @@ void ggml_build_backward_expand( break; } for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_find_node(cgraph, node->src[j])]) { + if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) { continue; } GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16); @@ -5674,11 +5669,12 @@ void ggml_build_backward_expand( GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { - cgraph->grads[i] = ggml_dup_tensor(ctx_static, node); - cgraph->grad_accs[i] 
= cgraph->grads[i]; + cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node); + cgraph->grad_accs[igrad] = cgraph->grads[igrad]; } - grads_needed[i] = true; + grads_needed[igrad] = true; } for (int i = n_nodes_f - 1; i >= 0; --i) { @@ -5705,8 +5701,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { - incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads - incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs } incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -5735,8 +5731,8 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; - struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grads_ptr = grads ? 
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -5797,23 +5793,23 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { dst->nodes[i] = src->nodes[i]; } - if (src->grads) { - GGML_ASSERT(dst->grads != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grads[i] = src->grads[i]; - } - GGML_ASSERT(dst->grad_accs != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grad_accs[i] = src->grad_accs[i]; - } - } - for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (ggml_bitset_get(src->visited_hash_set.used, i)) { ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } + + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + GGML_ASSERT(dst->grad_accs != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + } } struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { @@ -5840,7 +5836,7 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; - struct ggml_tensor * grad_acc = cgraph->grad_accs[i]; + struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node); if (node->op == GGML_OP_OPT_STEP_ADAMW) { // clear momenta @@ -5927,13 +5923,13 @@ struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, co } struct 
ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - const int i = ggml_find_node(cgraph, node); - return i >= 0 ? cgraph->grads[i] : NULL; + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL; } struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - const int i = ggml_find_node(cgraph, node); - return i >= 0 ? cgraph->grad_accs[i] : NULL; + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL; } void ggml_graph_print(const struct ggml_cgraph * cgraph) {