CUDA: non-contiguous (RMS) norm support #11659
Conversation
@@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case {
 struct test_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     float eps;
+    const bool v; // whether a is a non-contiguous view
Just making a note here: we cannot keep adding parameters like this to every op test case, it makes them much more complex. This will need to be refactored at some point and replaced with a generic way to create non-contiguous views for the op parameters.
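A generic helper along those lines might look like the following sketch (the helper name and the over-allocate-then-view strategy are illustrative assumptions, not code from this PR):

static ggml_tensor * make_noncontiguous_view(ggml_context * ctx, ggml_type type,
                                             const std::array<int64_t, 4> & ne) {
    // over-allocate the parent in dim 0, so the view's row stride nb[1] is
    // larger than its row size and the view is not contiguous
    ggml_tensor * parent = ggml_new_tensor_4d(ctx, type, ne[0]*2, ne[1], ne[2], ne[3]);
    return ggml_view_4d(ctx, parent, ne[0], ne[1], ne[2], ne[3],
                        parent->nb[1], parent->nb[2], parent->nb[3], 0);
}

Each op test could then route its inputs through one such helper instead of carrying its own per-test view flag.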
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Force-pushed from 8252615 to 57d170f.
@JohannesGaessler Apply this patch to fix the Metal build:

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 3ae4bbdd1..0a264be37 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1206,11 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(src0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
             return true;
         case GGML_OP_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(src0);
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
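(For reference: the patch closes the unbalanced parenthesis, replaces the undefined src0 with op->src[0], and for GGML_OP_NORM drops the ne[0] % 4 == 0 restriction in favor of a plain ggml_is_contiguous check.)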
This PR adds CUDA support for non-contiguous input tensors for (RMS) norm.
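The core of such support is addressing each row through the tensor's byte strides instead of assuming densely packed data. Below is a minimal CUDA sketch of that idea; the kernel name, the F32-only path, and the launch configuration are illustrative assumptions, not the PR's actual implementation:

// One block per row; grid dims (x, y, z) map to tensor dims 1..3.
// nb1/nb2/nb3 are byte strides, as in ggml, so padded views also work.
__global__ void rms_norm_noncont(const char * src, float * dst, const int ne0,
                                 const size_t nb1, const size_t nb2, const size_t nb3,
                                 const float eps) {
    // locate the row via byte strides; this is the only change vs. the contiguous path
    const float * x = (const float *)(src + blockIdx.x*nb1 + blockIdx.y*nb2 + blockIdx.z*nb3);
    // the output is written contiguously
    float * y = dst + ((blockIdx.z*gridDim.y + blockIdx.y)*gridDim.x + blockIdx.x)*(size_t)ne0;

    extern __shared__ float s[];

    float sum = 0.0f;
    for (int i = threadIdx.x; i < ne0; i += blockDim.x) {
        sum += x[i]*x[i];
    }
    s[threadIdx.x] = sum;
    __syncthreads();

    // tree reduction in shared memory; blockDim.x assumed to be a power of two
    for (int off = blockDim.x/2; off > 0; off /= 2) {
        if (threadIdx.x < off) {
            s[threadIdx.x] += s[threadIdx.x + off];
        }
        __syncthreads();
    }

    const float scale = rsqrtf(s[0]/ne0 + eps);
    for (int i = threadIdx.x; i < ne0; i += blockDim.x) {
        y[i] = scale*x[i];
    }
}

A launch for a tensor of shape (ne0, ne1, ne2, ne3) would then be rms_norm_noncont<<<dim3(ne1, ne2, ne3), 256, 256*sizeof(float)>>>((const char *) src->data, dst_f32, ne0, src->nb[1], src->nb[2], src->nb[3], eps); only the row addressing differs from a contiguous kernel.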