From eaea67b10e547e9837796b4ad9bafb3fdcee2a6d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 3 Nov 2024 15:18:40 +0200
Subject: [PATCH] metal : minor fixup in FA kernel (#10143)

* metal : minor fixup in FA kernel

ggml-ci

* metal : use the unrolled loop variable

* metal : remove unused var
---
 ggml/src/ggml-metal.metal | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index defde6246f129..57eb34f13ac85 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const short iv3 = iq3 / rv3;
 
         // load the queries from shared memory into local memory
-        float4 mq[D4];
+        float4 mq[D4/NW];
 
         for (short ii = 0; ii < D4; ii += NW) {
             short i = ii + tiisg;
-            mq[i] = (float4) sq4[i];
+            mq[ii/NW] = (float4) sq4[i];
         }
 
         // pointer to the mask
@@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
                         mk[2] = (float4) pk4[i + 2*(nb11/8)];
                         mk[3] = (float4) pk4[i + 3*(nb11/8)];
 
-                        mqk += (float4) (mq[i] * mk);
+                        mqk += (float4) (mq[ii/NW] * mk);
                     }
 
                     // reduce the results from the threads in the simdgroup
@@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
                 // O = diag(ms)*O
 #pragma unroll
                 for (short ii = 0; ii < D4; ii += NW) {
-                    const short i = ii + tiisg;
-                    lo[i/NW] *= ms;
+                    lo[ii/NW] *= ms;
                 }
             }
 
@@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
                     for (short ii = 0; ii < D4; ii += NW) {
                         const short i = ii + tiisg;
 
-                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
-                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
-                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
-                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                        lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
                     }
                 }
             }