From 6384f6c86d710001a9f4aa93e33172cf1fa2ba79 Mon Sep 17 00:00:00 2001
From: LeoYe
Date: Mon, 7 Oct 2024 10:35:50 -0700
Subject: [PATCH 1/2] update readme

---
 Diff-Transformer/README.md                | 7 ++++++-
 Diff-Transformer/multihead_flashdiff_1.py | 2 +-
 Diff-Transformer/multihead_flashdiff_2.py | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/Diff-Transformer/README.md b/Diff-Transformer/README.md
index 66311ca79..4017714a2 100644
--- a/Diff-Transformer/README.md
+++ b/Diff-Transformer/README.md
@@ -2,4 +2,9 @@
 
 ## Approach
 
-
\ No newline at end of file
+
+
+## Contents
+`multihead_diffattn.py` contains naive implementation of multi-head differential attention.
+`multihead_flashdiff_1.py` contains multi-head differential attention implemented with FlashAttention, for packages that support different qk/v dimensions (e.g., our [customized-flash-attention](https://github.com/xiayuqing0622/customized-flash-attention) and [xformers](https://github.com/facebookresearch/xformers)).
+`multihead_flashdiff_2.py` contains multi-head differential attention implemented with FlashAttention, for packages that **do not** support different qk/v dimensions (e.g., [flash-attention](https://github.com/Dao-AILab/flash-attention)).
\ No newline at end of file
diff --git a/Diff-Transformer/multihead_flashdiff_1.py b/Diff-Transformer/multihead_flashdiff_1.py
index e66e23756..acca7ed7a 100644
--- a/Diff-Transformer/multihead_flashdiff_1.py
+++ b/Diff-Transformer/multihead_flashdiff_1.py
@@ -38,7 +38,7 @@ def lambda_init_fn(depth):
 class MultiheadFlashDiff1(nn.Module):
     """
     (Recommended)
-    DiffAttn implemented with FlashAttention, for packages that support different qkv dimensions
+    DiffAttn implemented with FlashAttention, for packages that support different qk/v dimensions
     e.g., our customized-flash-attention (https://github.com/xiayuqing0622/customized-flash-attention) and xformers (https://github.com/facebookresearch/xformers)
     """
     def __init__(
diff --git a/Diff-Transformer/multihead_flashdiff_2.py b/Diff-Transformer/multihead_flashdiff_2.py
index b09d8b14f..3779a2214 100644
--- a/Diff-Transformer/multihead_flashdiff_2.py
+++ b/Diff-Transformer/multihead_flashdiff_2.py
@@ -37,7 +37,7 @@ def lambda_init_fn(depth):
 
 class MultiheadFlashDiff2(nn.Module):
     """
-    DiffAttn implemented with FlashAttention, for packages that does not support different qkv dimensions
+    DiffAttn implemented with FlashAttention, for packages that does not support different qk/v dimensions
     e.g., flash-attention (https://github.com/Dao-AILab/flash-attention)
     """
     def __init__(

From f80f28dcc7f68eaf3acf7f68259790dc9b4e38ce Mon Sep 17 00:00:00 2001
From: LeoYe
Date: Mon, 7 Oct 2024 10:41:27 -0700
Subject: [PATCH 2/2] update readme

---
 Diff-Transformer/README.md                | 4 +++-
 Diff-Transformer/multihead_flashdiff_1.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/Diff-Transformer/README.md b/Diff-Transformer/README.md
index 4017714a2..d293de611 100644
--- a/Diff-Transformer/README.md
+++ b/Diff-Transformer/README.md
@@ -6,5 +6,7 @@
 
 ## Contents
 `multihead_diffattn.py` contains naive implementation of multi-head differential attention.
-`multihead_flashdiff_1.py` contains multi-head differential attention implemented with FlashAttention, for packages that support different qk/v dimensions (e.g., our [customized-flash-attention](https://github.com/xiayuqing0622/customized-flash-attention) and [xformers](https://github.com/facebookresearch/xformers)).
+
+`multihead_flashdiff_1.py` contains multi-head differential attention implemented with FlashAttention, for packages that support different qk/v dimensions (e.g., our [customized-flash-attention](https://aka.ms/flash-diff) and [xformers](https://github.com/facebookresearch/xformers)).
+
 `multihead_flashdiff_2.py` contains multi-head differential attention implemented with FlashAttention, for packages that **do not** support different qk/v dimensions (e.g., [flash-attention](https://github.com/Dao-AILab/flash-attention)).
\ No newline at end of file
diff --git a/Diff-Transformer/multihead_flashdiff_1.py b/Diff-Transformer/multihead_flashdiff_1.py
index acca7ed7a..d2399b2b3 100644
--- a/Diff-Transformer/multihead_flashdiff_1.py
+++ b/Diff-Transformer/multihead_flashdiff_1.py
@@ -39,7 +39,7 @@ class MultiheadFlashDiff1(nn.Module):
     """
     (Recommended)
     DiffAttn implemented with FlashAttention, for packages that support different qk/v dimensions
-    e.g., our customized-flash-attention (https://github.com/xiayuqing0622/customized-flash-attention) and xformers (https://github.com/facebookresearch/xformers)
+    e.g., our customized-flash-attention (https://aka.ms/flash-diff) and xformers (https://github.com/facebookresearch/xformers)
     """
     def __init__(
         self,
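For reference, every module touched by these patches implements differential attention: two softmax attention maps are computed from split query/key projections, and their difference, weighted by a scalar lambda, is applied to the values. The sketch below illustrates only that core computation; the function name, tensor shapes, and the fixed lam value are illustrative assumptions, not the API of the repo's multihead_diffattn.py, which additionally handles multiple heads and a learned lambda whose initialization comes from lambda_init_fn. The FlashAttention variants compute the same quantity with fused kernels; per the README text above, multihead_flashdiff_1.py targets kernels that accept a value dimension different from the query/key dimension, while multihead_flashdiff_2.py targets kernels that do not.

# Illustrative sketch only (not the repo's module): single head, no causal
# mask, no output normalization, and a fixed lambda instead of the learned one.
import math

import torch
import torch.nn.functional as F


def diff_attention(q1, q2, k1, k2, v, lam):
    # q1, q2, k1, k2: (batch, seq, head_dim); v: (batch, seq, 2 * head_dim).
    scale = 1.0 / math.sqrt(q1.size(-1))
    # Two ordinary softmax attention maps over the same sequence.
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # Differential attention: subtract the second map, scaled by lambda,
    # before applying the result to the values.
    return (a1 - lam * a2) @ v


if __name__ == "__main__":
    b, n, d = 2, 16, 64
    q1, q2, k1, k2 = (torch.randn(b, n, d) for _ in range(4))
    v = torch.randn(b, n, 2 * d)
    out = diff_attention(q1, q2, k1, k2, v, lam=0.8)
    print(out.shape)  # torch.Size([2, 16, 128])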