Update train.py (#897)
* add flash_attn_kvpacked

* fix formatting

* accept changes from main & resolve conflicts

* Error

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* errors

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* feat(ci): add pip caching to CI

* Set training attribute appropriately

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* Split up FlashAttention methods

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* Comment out clear_cache

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* Just remove clear_cache

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* Fix pre-commit formatting

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* Changed is_pipe_parallel setting to fix pipeline-parallel inference (#866)

* Changed is_pipe_parallel setting to fix pipeline-parallel inference

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Quentin Anthony <qganthony@yahoo.com>

* feat: improve typing

* Added DeeperSpeed to requirements.txt

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* Update train.py

Update train.py:
1. apply Black formatting
2. remove an unnecessary import
3. add more arguments

* Update utils.py

Black formatting
Add logic required to expand "~"

* Update train.py

removed num_proc
temporarily disabled emoji
added continuing subword prefix option (does not work well with ByteLevel)

* Update utils.py

improve reader error handling

* Update train.py

add whitespace-related handling
expose a whitespace argument
reconstruct pre_tokenizer_list
add more whitespace to the tokenizer invertibility check

* Update train.py

* Update utils.py

remove unnecessary print

* Update train.py

set dropout default to None
import path-related code
change normalizer
change buffer_tokens
change whitespace reservation handling

* Update train.py

Clear whitespace_reservation TODO
add single_whitespace argument (might be necessary for invertibility)

* Create .gitignore

add gitignore file to ignore artifacts

* Update train.py

add directory parsing error checks
add more metrics (tokenizer reconstructions, unicode fallback portion)

* Update preprocess.py

path handling changes
black formatting

* Update train.py

change from GPT2TokenizerFast to PreTrainedTokenizerFast class

* Update train.py

enhanced test string

* Update utils.py

add logic to handle jsonl and txt input
add logic to handle a folder of jsonl/txt files or an Arrow dataset

* Update train.py

expose byte_fallback option
(incompatible with the current transformer wrapper)
change dataset loading to use the new utils.py
add dataset shuffling option

* Update utils.py

fix error in loading sequence

* Update train.py

fix whitespace preservation logic

* Update train.py

simplify data loading logic.
remove unnecessary special tokens

* Update train.py

remove emoji-related code

* Update train.py

add whitespace processing regex
r"\s{16,}"

* update tokenizer

add whitespace pretokenizer
(only processes long whitespace runs; see the sketch just before the file diffs below)

* Update train.py

* Update train.py

add camel case regex

* Update train.py

separate camel_case regex

* Update train.py

* Update train.py

---------

Signed-off-by: Dashiell Stander <dstander@protonmail.com>
Co-authored-by: Satpal Singh Rathore <satpal.code@gmail.com>
Co-authored-by: Dashiell Stander <dstander@protonmail.com>
Co-authored-by: Saurav Maheshkar <sauravvmaheshkar@gmail.com>
Co-authored-by: Stella Biderman <stellabiderman@gmail.com>
Co-authored-by: Curt Tigges <ct@curttigges.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Quentin Anthony <qganthony@yahoo.com>
8 people authored Apr 26, 2023
1 parent 59daa5e commit de91412
Showing 17 changed files with 762 additions and 155 deletions.
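The train.py and utils.py hunks for the tokenizer-training script are not included in the diff excerpt below, so the following is only a hedged sketch of the kind of setup the commit messages above describe, using the Hugging Face tokenizers library. The vocabulary size, special token, corpus, and test string are illustrative assumptions, not values taken from the actual script.

# Hedged sketch only -- illustrates the pre-tokenizer / wrapper flow described
# in the commit messages; names and hyperparameters are assumptions.
from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers, trainers
from transformers import PreTrainedTokenizerFast

# Pre-tokenizer list: split out only long whitespace runs (r"\s{16,}") so that
# ordinary indentation is preserved, then apply byte-level pre-tokenization.
pre_tok = pre_tokenizers.Sequence(
    [
        pre_tokenizers.Split(Regex(r"\s{16,}"), behavior="isolated"),
        pre_tokenizers.ByteLevel(add_prefix_space=False),
    ]
)

tokenizer = Tokenizer(models.BPE(dropout=None))  # dropout default set to None
tokenizer.pre_tokenizer = pre_tok
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=32_000,                                    # illustrative
    special_tokens=["<|endoftext|>"],                     # illustrative
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
corpus = ["def f(x):\n    return x", "plain text with                    long spaces"]
tokenizer.train_from_iterator(corpus, trainer=trainer)

# Wrap with PreTrainedTokenizerFast (rather than GPT2TokenizerFast) and
# spot-check invertibility on whitespace-heavy input.
wrapped = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
sample = "a" + " " * 24 + "b\n\tc"
roundtrip = tokenizer.decode(tokenizer.encode(sample).ids)
print("invertible:", roundtrip == sample)

The Split pre-tokenizer with behavior="isolated" keeps long-whitespace runs as standalone pieces rather than discarding them, which is what makes the round-trip check meaningful.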
2 changes: 2 additions & 0 deletions .github/workflows/cpu_ci.yml
@@ -12,6 +12,8 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: "3.8"
cache: "pip"
cache-dependency-path: "**/requirements*.txt"

- name: Upgrade Pip
run: python -m pip install --upgrade pip
4 changes: 3 additions & 1 deletion .github/workflows/pull_request.yml
@@ -7,9 +7,11 @@ jobs:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/setup-python@v4
with:
python-version: 3.8
cache: "pip"
cache-dependency-path: "**/requirements*.txt"
- uses: pre-commit/action@v2.0.3

update-documentation:
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 70f0c70
Default = 1b1e4eb

current git hash of repository

20 changes: 13 additions & 7 deletions megatron/logging.py
@@ -13,6 +13,7 @@
# limitations under the License.

import sys

import torch

try:
@@ -27,7 +28,7 @@
class Tee:
"""Duplicate output to both stdout/err and file"""

def __init__(self, file, err=False):
def __init__(self, file, err: bool = False) -> None:
self.file = open(file, "w")
self.err = err
if not err:
@@ -37,14 +38,14 @@ def __init__(self, file, err=False):
self.std = sys.stderr
sys.stderr = self

def __del__(self):
def __del__(self) -> None:
if not self.err:
sys.stdout = self.std
else:
sys.stderr = self.std
self.file.close()

def write(self, data):
def write(self, data) -> None:
try:
self.file.write(data)
except OSError:
@@ -54,14 +55,14 @@ def write(self, data):
except OSError:
pass

def flush(self):
def flush(self) -> None:
try:
self.file.flush()
except OSError:
pass
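
A small usage sketch for the Tee helper above (not part of the commit); the log-file name is illustrative.

# Mirror stdout into a file: Tee.__init__ replaces sys.stdout with itself, and,
# per the class docstring, output is duplicated to both the file and the
# original stream.
tee = Tee("train_stdout.log", err=False)   # illustrative path
print("this line goes to the console and to train_stdout.log")
del tee                                    # __del__ restores sys.stdout and closes the file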


def human_readable_flops(num):
def human_readable_flops(num) -> str:
for unit in [
"",
"KFLOPS",
@@ -78,7 +79,7 @@ def human_readable_flops(num):
return "%.1f%s" % (num, "Yi")


def get_flops(neox_args, model, iter_time_s):
def get_flops(neox_args, model, iter_time_s) -> float:
world_size = torch.distributed.get_world_size()
ff = model.total_params * 6
attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60
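
To make the estimate concrete, here is a small sketch (not part of the commit) that evaluates the two visible cost terms for an assumed model size; the parameter count and dimensions are illustrative, and the remainder of get_flops is collapsed in this view.

# Illustrative, roughly GPT-NeoX-20B-sized numbers; not taken from the diff.
total_params = 20e9
seq_length, hidden_size, num_layers = 2048, 6144, 44

ff = total_params * 6                               # parameter term used above
attn = seq_length * hidden_size * num_layers * 60   # attention term used above

print("ff   =", human_readable_flops(ff))           # formatted with the helper above
print("attn =", human_readable_flops(attn))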
@@ -358,7 +359,12 @@ def add_to_logging(name):


def tb_wandb_log(
key, value, iteration_no, use_wandb, tensorboard_writer=None, all_ranks=False
key: str,
value: float,
iteration_no: int,
use_wandb: bool,
tensorboard_writer=None,
all_ranks: bool = False,
):
# logs to both tb and wandb (if present) from the zeroth rank
do_log = torch.distributed.get_rank() == 0 or all_ranks
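
A usage sketch for the typed helper above (not part of the commit); it assumes torch.distributed has been initialized, and the metric name and writer path are illustrative.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")  # illustrative path

# Log a scalar to TensorBoard (and to Weights & Biases when use_wandb=True)
# from rank 0 only, since all_ranks is left at its default of False.
tb_wandb_log(
    key="train/lm_loss",
    value=2.31,
    iteration_no=100,
    use_wandb=False,
    tensorboard_writer=writer,
    all_ranks=False,
)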
278 changes: 275 additions & 3 deletions megatron/model/flash_attention.py
@@ -8,6 +8,12 @@
import flash_attn_cuda


def flash_attn_unpadded_unpacked_func_triton(
q, k, v, bias=None, causal=False, softmax_scale=None
):
return flash_attn_triton.flash_attn_func(q, k, v, bias, causal, softmax_scale)


def _flash_attn_forward_cuda(
q,
k,
@@ -186,7 +192,273 @@ def flash_attn_unpadded_qkvpacked_func_cuda(
)


-def flash_attn_unpadded_qkvpacked_func_triton(
-    q, k, v, bias=None, causal=False, softmax_scale=None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
q,
kv[:, 0],
kv[:, 1],
torch.empty_like(q),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax,
)
ctx.save_for_backward(
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
(
q,
kv,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
rng_state,
) = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward_cuda(
dout,
q,
kv[:, 0],
kv[:, 1],
out,
softmax_lse,
dq,
dkv[:, 0],
dkv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None


def flash_attn_unpadded_kvpacked_func_cuda(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
-    return flash_attn_triton.flash_attn_func(q, k, v, bias, causal, softmax_scale)
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_attn_probs,
)
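
A minimal usage sketch for the kv-packed entry point (not part of the commit); shapes, lengths, and dtype are illustrative, and it assumes a CUDA device with the flash-attn extension built.

import torch
import torch.nn.functional as F

# Two variable-length sequences packed along the token dimension, following the
# docstring above: q is (total_q, nheads, headdim), kv is (total_k, 2, nheads, headdim).
nheads, headdim = 16, 64
seqlens = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
total = int(seqlens.sum())

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
kv = torch.randn(total, 2, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_unpadded_kvpacked_func_cuda(
    q,
    kv,
    cu_seqlens,
    cu_seqlens,
    max_seqlen_q=int(seqlens.max()),
    max_seqlen_k=int(seqlens.max()),
    dropout_p=0.0,  # keep 0.0 during evaluation, per the docstring
    causal=True,
)
out.sum().backward()  # gradients are routed through FlashAttnKVPackedFunc.backward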


class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
q,
k,
v,
torch.empty_like(q),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax,
)
ctx.save_for_backward(
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
(
q,
k,
v,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
rng_state,
) = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward_cuda(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attn_unpadded_func_cuda(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_attn_probs,
)
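
And a corresponding sketch for the unpacked q/k/v variant (again illustrative, not part of the commit), here with different query and key lengths as in cross-attention.

import torch
import torch.nn.functional as F

nheads, headdim = 16, 64
seqlens_q = torch.tensor([64, 96], dtype=torch.int32, device="cuda")
seqlens_k = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

q = torch.randn(int(seqlens_q.sum()), nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(int(seqlens_k.sum()), nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

out = flash_attn_unpadded_func_cuda(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=int(seqlens_q.max()),
    max_seqlen_k=int(seqlens_k.max()),
    dropout_p=0.0,
    causal=False,  # causal masking is usually disabled for cross-attention
)
print(out.shape)  # (total_q, nheads, headdim)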
(Diff for the remaining changed files not shown.)
