diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 685279a1f5..229d79caf7 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -19,8 +19,8 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - torch: ["1.13.1"] - python-version: ["3.8"] + torch: ["2.2.2"] + python-version: ["3.10"] steps: - uses: actions/checkout@v1 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 84be3a6c17..6dfb93c1ee 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -32,7 +32,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.9 + python-version: 3.10.14 architecture: x64 - name: Fetch Wenet uses: actions/checkout@v1 @@ -60,7 +60,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.x + python-version: 3.10.14 architecture: x64 - name: Fetch Wenet uses: actions/checkout@v1 @@ -88,7 +88,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.x + python-version: 3.10.14 architecture: x64 - name: Fetch Wenet uses: actions/checkout@v1 diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 1f48d32106..a6122c0a59 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -12,7 +12,7 @@ jobs: max-parallel: 20 matrix: os: [ubuntu-latest] - python-version: [3.8] + python-version: [3.10.14] steps: - name: Cache Python Packages uses: actions/cache@v1 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 9992cbd108..990acc92a9 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -21,13 +21,13 @@ jobs: # Used to host cibuildwheel - uses: actions/setup-python@v3 with: - python-version: '3.6' + python-version: '3.10' - name: Build wheels uses: pypa/cibuildwheel@v2.11.2 env: CIBW_BUILD_VERBOSITY: 1 - CIBW_BUILD: "cp36-* cp37-* cp38-* cp39-*" + CIBW_BUILD: "cp36-* cp37-* cp38-* cp39-* cp310-*" # Disable building PyPy wheels on all platforms # Skip 32-bit builds CIBW_SKIP: "pp* *-win32 *-manylinux_i686 *-musllinux_*" diff --git a/README.md b/README.md index 4914db65a2..95824b25a4 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ git clone https://github.com/wenet-e2e/wenet.git - Create Conda env: ``` sh -conda create -n wenet python=3.8 +conda create -n wenet python=3.10 conda activate wenet conda install conda-forge::sox pip install -r requirements.txt diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml index 5136f1ad9e..0e789dda3d 100644 --- a/examples/aishell/s0/conf/train_ebranchformer.yaml +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -18,7 +18,7 @@ encoder_conf: activation_type: 'swish' causal: false pos_enc_layer_type: 'rel_pos' - attention_layer_type: 'rel_selfattn' + selfattention_layer_type: 'rel_selfattn' # decoder related decoder: transformer diff --git a/examples/aishell/s0/conf/train_u2++_branchformer.yaml b/examples/aishell/s0/conf/train_u2++_branchformer.yaml index 8702fbeb43..37fda91157 100644 --- a/examples/aishell/s0/conf/train_u2++_branchformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_branchformer.yaml @@ -5,7 +5,7 @@ encoder_conf: output_size: 256 use_attn: true attention_heads: 4 - attention_layer_type: rel_selfattn + selfattention_layer_type: rel_selfattn pos_enc_layer_type: rel_pos use_cgmlp: true cgmlp_linear_units: 2048 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index b036bb7dd1..fb0c1133fe 
100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -55,7 +55,7 @@ dir=exp/conformer tensorboard_dir=tensorboard checkpoint= num_workers=8 -prefetch=500 +prefetch=10 # use average_checkpoint will get better result average_checkpoint=true diff --git a/examples/aishell/whisper/conf/ds_stage1.json b/examples/aishell/whisper/conf/ds_stage1.json index b708260c9f..51804c1f1d 100644 --- a/examples/aishell/whisper/conf/ds_stage1.json +++ b/examples/aishell/whisper/conf/ds_stage1.json @@ -23,40 +23,11 @@ "device": "none", "pin_memory": true }, - "offload_param": { - "device": "none", - "pin_memory": true - }, "allgather_partitions": true, "allgather_bucket_size": 5e8, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 5e8, - "contiguous_gradients" : true, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": 1e6 - }, - "activation_checkpointing": { - "partition_activations": false, - "cpu_checkpointing": false, - "contiguous_memory_optimization": false, - "number_checkpoints": null, - "synchronize_checkpoint_boundary": false, - "profile": false - }, - "flops_profiler": { - "enabled": false, - "profile_step": 100, - "module_depth": -1, - "top_modules": 1, - "detailed": true, - "output_file": null - }, - "tensorboard": { - "enabled": false, - "output_path": "tensorboard/ds_logs/", - "job_name": "deepspeed" + "contiguous_gradients" : true } } diff --git a/examples/aishell/whisper/conf/ds_stage2.json b/examples/aishell/whisper/conf/ds_stage2.json new file mode 100644 index 0000000000..c11b7d61e4 --- /dev/null +++ b/examples/aishell/whisper/conf/ds_stage2.json @@ -0,0 +1,33 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 5, + "fp16": { + "enabled": false, + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "consecutive_hysteresis": false, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": true + }, + "zero_force_ds_cpu_optimizer": false, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients" : true + } +} diff --git a/examples/aishell/whisper/conf/ds_stage3.json b/examples/aishell/whisper/conf/ds_stage3.json new file mode 100644 index 0000000000..ba382935a6 --- /dev/null +++ b/examples/aishell/whisper/conf/ds_stage3.json @@ -0,0 +1,41 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 5, + "fp16": { + "enabled": false, + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "consecutive_hysteresis": false, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": true + }, + "zero_force_ds_cpu_optimizer": false, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients" : true, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_prefetch_bucket_size": 
5e8, + "stage3_param_persistence_threshold": 1e5 + } +} diff --git a/examples/aishell/whisper/run.sh b/examples/aishell/whisper/run.sh index e4045d21a9..5b4af23984 100755 --- a/examples/aishell/whisper/run.sh +++ b/examples/aishell/whisper/run.sh @@ -43,7 +43,7 @@ checkpoint=exp/whisper/large-v3/wenet_whisper.init-ctc.pt dir=exp/finetune_whisper_largev3_conv1d2 tensorboard_dir=tensorboard num_workers=8 -prefetch=500 +prefetch=10 # use average_checkpoint will get better result average_checkpoint=true diff --git a/examples/librispeech/s0/conf/train_u2++_branchformer.yaml b/examples/librispeech/s0/conf/train_u2++_branchformer.yaml index 3b79614442..f643831ca0 100644 --- a/examples/librispeech/s0/conf/train_u2++_branchformer.yaml +++ b/examples/librispeech/s0/conf/train_u2++_branchformer.yaml @@ -5,7 +5,7 @@ encoder_conf: output_size: 256 use_attn: true attention_heads: 4 - attention_layer_type: rel_selfattn + selfattention_layer_type: rel_selfattn pos_enc_layer_type: rel_pos use_cgmlp: true cgmlp_linear_units: 2048 diff --git a/examples/wenetspeech/s0/conf/train_u2++_conformer.yaml b/examples/wenetspeech/s0/conf/train_u2++_conformer.yaml index a57bfd0dc6..9edb8a940d 100755 --- a/examples/wenetspeech/s0/conf/train_u2++_conformer.yaml +++ b/examples/wenetspeech/s0/conf/train_u2++_conformer.yaml @@ -104,7 +104,7 @@ dataset_conf: grad_clip: 5 accum_grad: 4 -max_epoch: 1 # NOTE(xcsong): Configure the epoch in run.sh +max_epoch: 100 log_interval: 100 save_interval: 1000 # NOTE(xcsong): we use step_save instead of epoch_save for large datasets diff --git a/examples/wenetspeech/s0/run.sh b/examples/wenetspeech/s0/run.sh index 43d7a33b19..65e9a44069 100755 --- a/examples/wenetspeech/s0/run.sh +++ b/examples/wenetspeech/s0/run.sh @@ -47,12 +47,12 @@ train_set=train_`echo $set | tr 'A-Z' 'a-z'` dev_set=dev test_sets="test_net test_meeting" -# NOTE(xcsong): we use step_save instead of epoch_save for large datasets -epoch=100 - train_config=conf/train_u2++_conformer.yaml checkpoint= dir=exp/u2pp_conformer +tensorboard_dir=tensorboard +num_workers=8 +prefetch=10 cmvn_sampling_divisor=20 # 20 means 5% of the training data to estimate cmvn @@ -157,19 +157,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then echo "$0: using torch ddp" fi - # repeat data.list, we use step_save instead of epoch_save for large datasets - train_data=data/$train_set/data.list.repeat${epoch} - if [ ! -f "${train_data}" ]; then - echo "repeat data/$train_set/data.list ${epoch} times" - for (( i=1; i<=$epoch; i++ )) - do - cat "data/$train_set/data.list" >> "${train_data}" - done - echo "save new data.list in ${train_data}, it will be used for training" - else - echo "${train_data} already exists." 
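The shell block being removed here (its closing `fi` follows) materialized `data/$train_set/data.list.repeat${epoch}` on disk so that a single pass over the repeated file covered many epochs; with `max_epoch: 100` restored and step-based saving (`save_interval`), the same repetition can happen lazily at iteration time. A minimal sketch, assuming a simple hypothetical generator rather than WeNet's actual dataset pipeline:

```python
from typing import Iterator

def cycle_data_list(path: str, epochs: int) -> Iterator[str]:
    """Yield every line of `path` `epochs` times without copying the file.

    Same effect as concatenating data.list N times into data.list.repeatN,
    but the repetition is done lazily while iterating.
    """
    for _ in range(epochs):
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                yield line.strip()

# Usage sketch (hypothetical path):
# for shard in cycle_data_list("data/train_l/data.list", epochs=100):
#     ...
```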
- fi - echo "$0: num_nodes is $num_nodes, proc_per_node is $num_gpus" torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus --rdzv_endpoint=$HOST_NODE_ADDR \ --rdzv_id=2023 --rdzv_backend="c10d" \ @@ -177,13 +164,16 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --train_engine ${train_engine} \ --config $train_config \ --data_type "shard" \ - --train_data ${train_data} \ + --train_data data/$train_set/data.list \ --cv_data data/$dev_set/data.list \ ${checkpoint:+--checkpoint $checkpoint} \ --model_dir $dir \ + --tensorboard_dir ${tensorboard_dir} \ --ddp.dist_backend $dist_backend \ - --num_workers 2 \ + --num_workers ${num_workers} \ + --prefetch ${prefetch} \ --pin_memory \ + --timeout 1200 \ --deepspeed_config ${deepspeed_config} \ --deepspeed.save_states ${deepspeed_save_states} fi diff --git a/examples/wenetspeech/whisper/conf/ds_stage1.json b/examples/wenetspeech/whisper/conf/ds_stage1.json index 4722e5e4e1..1a04208271 100644 --- a/examples/wenetspeech/whisper/conf/ds_stage1.json +++ b/examples/wenetspeech/whisper/conf/ds_stage1.json @@ -23,19 +23,11 @@ "device": "none", "pin_memory": true }, - "offload_param": { - "device": "none", - "pin_memory": true - }, "allgather_partitions": true, "allgather_bucket_size": 5e8, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 5e8, - "contiguous_gradients" : true, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": 1e6 + "contiguous_gradients" : true } } diff --git a/examples/wenetspeech/whisper/conf/finetune_whisper_largev3.yaml b/examples/wenetspeech/whisper/conf/finetune_whisper_largev3.yaml index cf36dfa1d9..dfccc16213 100644 --- a/examples/wenetspeech/whisper/conf/finetune_whisper_largev3.yaml +++ b/examples/wenetspeech/whisper/conf/finetune_whisper_largev3.yaml @@ -108,7 +108,7 @@ dataset_conf: grad_clip: 5 accum_grad: 8 -max_epoch: 1 # NOTE(xcsong): Configure the epoch in run.sh +max_epoch: 100 log_interval: 100 save_interval: 1000 # NOTE(xcsong): we use step_save instead of epoch_save for large datasets diff --git a/examples/wenetspeech/whisper/run.sh b/examples/wenetspeech/whisper/run.sh index bbd177233e..886ac3523a 100755 --- a/examples/wenetspeech/whisper/run.sh +++ b/examples/wenetspeech/whisper/run.sh @@ -44,13 +44,12 @@ train_set=train_l dev_set=dev test_sets="test_net test_meeting" -epoch=100 train_config=conf/finetune_whisper_largev3.yaml checkpoint=exp/whisper/large-v3/wenet_whisper.init-ctc.pt dir=exp/finetune_whisper_largev3 tensorboard_dir=tensorboard -num_workers=1 -prefetch=500 +num_workers=8 +prefetch=10 # use average_checkpoint will get better result average_checkpoint=true @@ -92,19 +91,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then echo "$0: using torch ddp" fi - # repeat data.list, we use step_save instead of epoch_save for large datasets - train_data=data/$train_set/data.list.repeat${epoch} - if [ ! -f "${train_data}" ]; then - echo "repeat data/$train_set/data.list ${epoch} times" - for (( i=1; i<=$epoch; i++ )) - do - cat "data/$train_set/data.list" >> "${train_data}" - done - echo "save new data.list in ${train_data}, it will be used for training" - else - echo "${train_data} already exists." - fi - # NOTE(xcsong): Both ddp & deepspeed can be launched by torchrun # NOTE(xcsong): To unify single-node & multi-node training, we add # all related args. 
You should change `nnodes` & @@ -128,7 +114,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --train_engine ${train_engine} \ --config $train_config \ --data_type $data_type \ - --train_data ${train_data} \ + --train_data data/$train_set/data.list \ --cv_data data/$dev_set/data.list \ ${checkpoint:+--checkpoint $checkpoint} \ --model_dir $dir \ diff --git a/requirements.txt b/requirements.txt index 1f498880d5..e78b45af78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ cpplint==1.6.1 torch>=2.1.2 torchaudio>=2.1.2 tqdm -deepspeed<0.13.0 +deepspeed>=0.14.0 librosa openai-whisper pre-commit==3.5.0 diff --git a/test/wenet/dataset/test_datapipes.py b/test/wenet/dataset/test_datapipes.py index f269788c9e..d36afa9c82 100644 --- a/test/wenet/dataset/test_datapipes.py +++ b/test/wenet/dataset/test_datapipes.py @@ -7,9 +7,10 @@ from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe, WenetRawDatasetSource, WenetTarShardDatasetSource) -from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding, - parse_json, compute_fbank, - detect_language, detect_task) +from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, + feats_length_fn, padding, parse_json, + compute_fbank, detect_language, + detect_task) @pytest.mark.parametrize("data_list", [ @@ -106,7 +107,8 @@ def test_dynamic_batch_datapipe(data_list): max_frames_in_batch = 10000 dataset = dataset.dynamic_batch( window_class=DynamicBatchWindow(max_frames_in_batch), - wrapper_class=padding) + wrapper_class=padding, + elem_size_fn=feats_length_fn) dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, diff --git a/test/wenet/transformer/test_attention.py b/test/wenet/transformer/test_attention.py index ca7e63d92c..ba2869d7d2 100644 --- a/test/wenet/transformer/test_attention.py +++ b/test/wenet/transformer/test_attention.py @@ -64,7 +64,8 @@ def test_multi_head_attention_sdpa(args): output_with_sdpa * mask.transpose(1, 2), atol=9e-7, ) - assert torch.allclose(cache, cache_with_sdpa) + assert torch.allclose(cache[0], cache_with_sdpa[0]) + assert torch.allclose(cache[1], cache_with_sdpa[1]) n_blocks = 12 torch.manual_seed(777) @@ -110,7 +111,8 @@ def test_multi_head_attention_sdpa(args): atol=9e-7, rtol=9e-4, ) - assert torch.allclose(cache, cache_with_sdpa) + assert torch.allclose(cache[0], cache_with_sdpa[0]) + assert torch.allclose(cache[1], cache_with_sdpa[1]) q = output @@ -170,7 +172,8 @@ def test_rel_position_multi_head_attention_sdpa(args): output_with_sdpa * mask.transpose(1, 2), atol=9e-7, ) - assert torch.allclose(cache, cache_with_sdpa) + assert torch.allclose(cache[0], cache_with_sdpa[0]) + assert torch.allclose(cache[1], cache_with_sdpa[1]) n_blocks = 12 torch.manual_seed(777) @@ -220,7 +223,8 @@ def test_rel_position_multi_head_attention_sdpa(args): atol=9e-7, rtol=9e-4, ) - assert torch.allclose(cache, cache_with_sdpa) + assert torch.allclose(cache[0], cache_with_sdpa[0]) + assert torch.allclose(cache[1], cache_with_sdpa[1]) q = output diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index e336950fd0..0f1e994541 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -318,7 +318,8 @@ def test_model(model, audio_path): attn_ln_x, attn_ln_x, masks, - cache=torch.zeros((0, 0, 0, 0))) + cache=(torch.zeros((0, 0, 0, 0)), + torch.zeros(0, 0, 0, 0))) wenet_layers_output.append({ "name": "enc.layer{}.attn".format(i), "value": x_att.clone() diff --git a/tools/compute_cmvn_stats.py 
b/tools/compute_cmvn_stats.py index 4dcad825d2..c68929436c 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -30,8 +30,7 @@ def __call__(self, batch): value = item[1].strip().split(",") assert len(value) == 3 or len(value) == 1 wav_path = value[0] - sample_rate = torchaudio.info( - wav_path).sample_rate + sample_rate = torchaudio.info(wav_path).sample_rate resample_rate = sample_rate # len(value) == 3 means segmented wav.scp, # len(value) == 1 means original wav.scp diff --git a/tools/wav2dur.py b/tools/wav2dur.py index d416b1ad96..296114961b 100755 --- a/tools/wav2dur.py +++ b/tools/wav2dur.py @@ -5,7 +5,6 @@ import torchaudio - scp = sys.argv[1] dur_scp = sys.argv[2] diff --git a/wenet/LLM/causal_model.py b/wenet/LLM/causal_model.py new file mode 100644 index 0000000000..b9192ce15d --- /dev/null +++ b/wenet/LLM/causal_model.py @@ -0,0 +1,208 @@ +from typing import Dict, List, Optional, Union +import torch +from wenet.LLM.decoder import DecoderOnly +from wenet.LLM.sampler import sampler +from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.utils.common import IGNORE_ID, th_accuracy +from wenet.utils.mask import make_pad_mask, subsequent_mask + + +class CausalLM(torch.nn.Module): + + def __init__( + self, + vocab_size: int, + decoder: DecoderOnly, + special_tokens: dict, + tie_word_embedding: bool = False, + linear_bias: bool = False, + ignore_id: int = IGNORE_ID, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ) -> None: + super().__init__() + del special_tokens + + self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size) + self.out = torch.nn.Linear(decoder.hidden_size, + vocab_size, + bias=linear_bias) + + self.decoder = decoder + self.vocab_size = vocab_size + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + self.tie_word_embedding = tie_word_embedding + self.ignore_id = ignore_id + + @torch.jit.unused + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ Forward for training + """ + text = batch['feats'].to(device) + target = batch['target'].to(device) + text_length = batch['feats_lengths'].to(device) + + mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze( + 1) # (B,1,L) + causal_mask = subsequent_mask( + mask.size(-1), device=mask.device).unsqueeze(0) # (1,L,L) + att_mask = causal_mask & mask # (B, L, L) + + embeding = self.embed(text) + decoder_out = self.out(self.decoder(embeding, + att_mask)[0]) # (B, L, vocab_size) + loss = self.criterion_att(decoder_out, target) + acc = th_accuracy(decoder_out.view(-1, self.vocab_size), + target, + ignore_label=self.ignore_id) + + return { + "loss": loss, + "ppl": torch.exp(loss.detach()), + "th_accuracy": acc + } + + def tie_or_clone_weights(self, jit_mode: bool): + if not self.tie_word_embedding: + return + if jit_mode: + self.out.weight = torch.nn.Parameter(self.embed.weight.clone()) + else: + self.out.weight = self.embed.weight + # TODO(Mddct): whether to deal bias for other llm model + + @torch.jit.unused + @torch.inference_mode() + def generate( + self, + prompts_tokens: List[List[int]], + device: torch.device, + stop_tokens: List[int], + dtype: torch.dtype = torch.float32, + output_len: int = 100, + temperature: Union[float, None] = 0.95, + top_p: float = 1.0, + top_k: int = 100, + ) -> List[List[int]]: + """Generates responses for given prompts using Gemma model.""" + # If a 
single prompt is provided, treat it as a batch of 1. + batch_size = len(prompts_tokens) + min_prompt_len = min(len(p) for p in prompts_tokens) + max_prompt_len = max(len(p) for p in prompts_tokens) + max_seq_len = max_prompt_len + output_len + assert max_seq_len <= self.decoder.pos_enc.max_len + + # build KV caches + kv_caches = [] + for _ in range(len(self.decoder.decoders)): + size = (batch_size, 0, self.decoder.n_kv_head, + self.decoder.head_dim) + k_cache = torch.zeros(size=size, dtype=dtype, device=device) + v_cache = torch.zeros(size=size, dtype=dtype, device=device) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full((batch_size, max_seq_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + input_token_ids_tensor = torch.full((batch_size, min_prompt_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + # right padding + for i, p in enumerate(prompts_tokens): + token_ids_tensor[i, :len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len]) + + prompt_mask_tensor = token_ids_tensor != IGNORE_ID + input_positions_tensor = torch.arange(0, + min_prompt_len, + dtype=torch.int64).to(device) + mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len), + dtype=torch.bool) + mask_tensor = torch.tril(mask_tensor).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze( + 1)[:, :min_prompt_len, :min_prompt_len] + output_positions_tensor = torch.LongTensor([min_prompt_len - 1 + ]).to(device) + temperatures_tensor = None if not temperature else torch.FloatTensor( + [temperature] * batch_size).to(device) + top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) + top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) + output_index = torch.tensor(min_prompt_len, + dtype=torch.int64).to(device) + + input_token_embeding = self.embed(input_token_ids_tensor) + offset = torch.tensor([0] * len(prompts_tokens)).to(device) + input_offset = offset + + stop_tokens_tensor = torch.tensor(stop_tokens, device=device) + # Prefill up to min_prompt_len tokens, then treat other prefill as + # decode and ignore output. 
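Before the decode loop that follows, note how `generate()` slices one precomputed lower-triangular mask per step with `index_select`; a standalone sketch of the same `torch.tril` + `index_select` pattern with toy sizes (not the model's real dimensions):

```python
import torch

max_seq_len = 8
# Full causal mask, built once: position i may attend to positions <= i.
mask = torch.tril(torch.ones((1, 1, max_seq_len, max_seq_len), dtype=torch.bool))

# Prefill: the first 3 prompt tokens attend among themselves.
prefill_positions = torch.arange(0, 3)
att_mask = mask.index_select(2, prefill_positions).squeeze(1)[:, :3, :3]
assert att_mask.shape == (1, 3, 3)

# Decode step at position 5: a single query row over the first 6 keys.
step = torch.tensor([5])
att_mask = mask.index_select(2, step).squeeze(1)[:, :, :6]
assert att_mask.shape == (1, 1, 6)
```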
+ for i in range(max_seq_len - min_prompt_len): + decoder_out, kv_caches, = self.decoder( + input_token_embeding, + att_mask, + input_offset, + kv_caches, + ) + decoder_out = self.out(decoder_out) + decoder_out = decoder_out.index_select(1, output_positions_tensor) + next_token_ids = sampler( + decoder_out, + temperatures_tensor, + top_ps_tensor, + top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select( + 1, output_index).squeeze(dim=1) + output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, + next_token_ids).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_token_embeding = self.embed(input_token_ids_tensor) + + input_positions_tensor = output_index.unsqueeze(dim=-1) + curr_mask_tensor = mask_tensor.index_select( + 2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze(1)[:, :output_index + + 1, :output_index + 1] + + output_positions_tensor = torch.tensor( + 0, dtype=torch.int64).to(device) + input_offset = offset + output_index.unsqueeze(-1) + output_index = output_index + 1 + + if all(torch.isin(next_token_ids, stop_tokens_tensor)): + break + + token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[len(prompts_tokens[i] + ):len(prompts_tokens[i]) + output_len] + for stop_token in stop_tokens: + try: + eos_index = trimmed_output.index(stop_token) + trimmed_output = trimmed_output[:eos_index] + break + except Exception: + continue + results.append(trimmed_output) + + return results diff --git a/wenet/LLM/decoder.py b/wenet/LLM/decoder.py new file mode 100644 index 0000000000..b25ee75dd6 --- /dev/null +++ b/wenet/LLM/decoder.py @@ -0,0 +1,161 @@ +from functools import partial +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint as ckpt +from wenet.transformer.attention import T_CACHE + +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, WENET_MLP_CLASSES, + WENET_NORM_CLASSES) +from wenet.utils.common import mask_to_bias + + +class DecoderOnly(torch.nn.Module): + + def __init__( + self, + n_kv_head: int, + head_dim: int, + hidden_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + normalize_before: bool = True, + query_bias: bool = False, + key_bias: bool = False, + value_bias: bool = False, + mlp_bias: bool = False, + activation_type: str = "gelu", + gelu_approximate: Union[str, None] = None, + max_position_embeding: int = 8192, + mlp_type: str = 'gated', + layer_norm_type: str = 'rms_norm', + norm_eps: float = 1e-5, + rms_norm_offset: bool = True, + selfattention_layer_type: str = "rope_abs_selfattn", + use_sdpa: bool = False, + gradient_checkpointing: bool = False, + rope_theta: float = 10000.0, + rope_style: str = 'google', + scale_embed: bool = True, + ) -> None: + super().__init__() + + assert selfattention_layer_type in ['rope_abs_selfattn'] + self.pos_enc = WENET_EMB_CLASSES["rope_pos"]( + hidden_size, + head_dim, + max_len=max_position_embeding, + dropout_rate=positional_dropout_rate, + rope_theta=rope_theta, + scale=scale_embed) + if activation_type == "gelu" and gelu_approximate is not None: + 
activation = WENET_ACTIVATION_CLASSES['gelu']( + approximate=gelu_approximate) + else: + activation = WENET_ACTIVATION_CLASSES[activation_type]() + + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.num_blocks = num_blocks + # TODO: support lora & refactor lora + self.decoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + hidden_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + attention_heads, + hidden_size, + attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, + style=rope_style), + mlp_class(hidden_size, linear_units, dropout_rate, activation, + mlp_bias), + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + rms_norm_offset=rms_norm_offset, + ) for _ in range(self.num_blocks) + ]) + self.pre_norm = normalize_before + self.final_norm: Optional[torch.nn.Module] = None + if self.pre_norm: + norm_class = WENET_NORM_CLASSES[layer_norm_type] + if layer_norm_type == "rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.final_norm = norm_class(hidden_size, eps=norm_eps) + + self.n_kv_head = n_kv_head + self.head_dim = head_dim + self._hidden_size = hidden_size + self.use_sdpa = use_sdpa + self.gradient_checkpointing = gradient_checkpointing + + def forward( + self, + input: torch.Tensor, + att_mask: torch.Tensor, + input_position: Union[int, torch.Tensor] = 0, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + xs, pos_emb = self.pos_enc(input, offset=input_position) + if self.use_sdpa: + att_mask = mask_to_bias(att_mask, xs.dtype) + + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb) + else: + xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb, + kv_caches) + if self.pre_norm and self.final_norm is not None: + xs = self.final_norm(xs) + return xs, kv_caches + + def forward_layers( + self, + xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + if self.training: + for (i, layer) in enumerate(self.decoders): + xs, _, _, _ = layer(xs, att_mask, pos_emb) + new_kv_caches = kv_caches + else: + assert kv_caches is not None + new_kv_caches = [] + for (i, layer) in enumerate(self.decoders): + xs, _, new_kv_cache, _ = layer(xs, + att_mask, + pos_emb, + att_cache=(kv_caches[i][0], + kv_caches[i][1])) + new_kv_caches.append(new_kv_cache) + + return xs, new_kv_caches + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask, + pos_emb) + return xs + + @property + def hidden_size(self): + return self._hidden_size diff --git a/wenet/LLM/sampler.py b/wenet/LLM/sampler.py new file mode 100644 index 0000000000..19f0d5cdaf --- /dev/null +++ b/wenet/LLM/sampler.py @@ -0,0 +1,43 @@ +from typing import Union +import torch + + +# modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26 +@torch.no_grad() +def sampler( + logits: torch.Tensor, + temperatures: Union[torch.Tensor, None], + top_ps: torch.Tensor, + top_ks: torch.Tensor, +) -> torch.Tensor: + assert logits.size(1) == 1 + logits = logits.squeeze(1) # (batch_size, vocab_size) + if temperatures is None: + return torch.argmax(logits, dim=-1).squeeze(dim=-1) + + 
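When `temperatures` is `None` the sampler above falls back to greedy argmax; otherwise the lines that follow apply temperature scaling and then top-p/top-k filtering. A toy numeric walk-through of that filtering, assuming an already-sorted distribution (standalone, not the function itself):

```python
import torch

# Toy, already-sorted distribution over 5 tokens (descending).
probs_sort = torch.tensor([[0.5, 0.3, 0.1, 0.07, 0.03]])
top_p, top_k = 0.75, 3

# Top-p: mask a token once the cumulative mass *before* it exceeds top_p.
probs_sum = torch.cumsum(probs_sort, dim=-1)
top_p_mask = (probs_sum - probs_sort) > top_p   # [[False, False, True, True, True]]
probs_sort = torch.where(top_p_mask, 0.0, probs_sort)

# Top-k: additionally keep only the first k sorted tokens.
ranks = torch.arange(probs_sort.shape[-1]).expand(1, -1)
top_k_mask = ranks >= top_k                     # [[False, False, False, True, True]]
probs_sort = torch.where(top_k_mask, 0.0, probs_sort)

# Re-normalize; only the two most likely tokens survive in this example.
probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
print(probs_sort)   # tensor([[0.6250, 0.3750, 0.0000, 0.0000, 0.0000]])
```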
# Apply temperature scaling. + logits.div_(temperatures.unsqueeze(dim=1)) + + # Calculate probabilities with softmax. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + + # Apply top-p, top-k. + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) + probs_sort = torch.where(top_ps_mask, 0, probs_sort) + + top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) + top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) + top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) + probs_sort = torch.where(top_ks_mask, 0, probs_sort) + + # Re-normalization. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + probs = torch.gather(probs_sort, + dim=-1, + index=torch.argsort(probs_idx, dim=-1)) + + next_token_ids = torch.multinomial(probs, num_samples=1, + replacement=True).squeeze(dim=-1) + return next_token_ids diff --git a/wenet/LLM/script/config.py b/wenet/LLM/script/config.py new file mode 100644 index 0000000000..c37959af45 --- /dev/null +++ b/wenet/LLM/script/config.py @@ -0,0 +1,205 @@ +import dataclasses +from typing import Dict, Optional, Union + +import yaml + + +# https://github.com/google/gemma_pytorch/blob/main/gemma/config.py#L32 +@dataclasses.dataclass +class Config: + vocab_size: int = 256000 + # The maximum sequence length that this model might ever be used with. + max_position_embeddings: int = 8192 + # The number of blocks in the model. + num_hidden_layers: int = 28 + # The number of attention heads used in the attention layers of the model. + num_attention_heads: int = 16 + # The number of key-value heads for implementing attention. + num_key_value_heads: int = 16 + # The hidden size of the model. + hidden_size: int = 3072 + # The dimension of the MLP representations. + intermediate_size: int = 24576 + # The number of head dimensions. + head_dim: int = 256 + # The epsilon used by the rms normalization layers. + rms_norm_eps: float = 1e-6 + # tope theta + rope_theta: float = 500000.0 + # rope style: google or llama + rope_style: str = 'google' + # rms_norm offset + rms_norm_offset: bool = True + # activation type + activation_type: str = 'gelu' + # gelu approximate + gelu_approximate: Union[str, None] = None + # The dtype of the weights. 
+ dtype: str = 'bfloat16' + + # scale embed + scale_embed: bool = True + + def to_wenet_config(self) -> Dict: + configs = {} + configs['max_position_embeding'] = self.max_position_embeddings + configs['num_blocks'] = self.num_hidden_layers + configs['attention_heads'] = self.num_attention_heads + configs['n_kv_head'] = self.num_key_value_heads + configs['head_dim'] = self.head_dim + configs['hidden_size'] = self.hidden_size + configs['linear_units'] = self.intermediate_size + configs['norm_eps'] = self.rms_norm_eps + configs['rope_theta'] = self.rope_theta + configs['activation_type'] = self.activation_type + configs['gelu_approximate'] = self.gelu_approximate + configs['rope_style'] = self.rope_style + configs['rms_norm_offset'] = self.rms_norm_offset + configs['scale_embed'] = self.scale_embed + return configs + + +def wenet_llm_tokenizer_conf(config: Config, tokenizer_path: str, + special_tokens: Dict) -> Dict: + configs = {} + configs['tokenizer'] = 'huggingface' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['model'] = tokenizer_path + configs['tokenizer_conf']['special_tokens'] = special_tokens + return configs + + +def wenet_llm_dataset_and_train_conf(config: Config, + template: str = 'gemma') -> Dict: + configs = {} + configs['dataset'] = "llm" + configs['dataset_conf'] = {} + configs['dataset_conf']['filter_conf'] = {} + configs['dataset_conf']['filter_conf'][ + 'token_max_length'] = config.max_position_embeddings + configs['dataset_conf']['filter_conf']['token_min_length'] = 1 + configs['dataset_conf']['shuffle'] = True + configs['dataset_conf']['shuffle_conf'] = {} + configs['dataset_conf']['shuffle_conf']['shuffle_size'] = 1500 + configs['dataset_conf']['shuffle_list'] = True + configs['dataset_conf']['shuffle_list_conf'] = {} + configs['dataset_conf']['shuffle_list_conf']['shuffle_size'] = 15000 + configs['dataset_conf']['sort'] = True + configs['dataset_conf']['sort_conf'] = {} + configs['dataset_conf']['sort_conf']['sort_size'] = 500 + configs['dataset_conf']['batch_conf'] = {} + configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic' + configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000 + + configs['dataset_conf']['data_style'] = 'sft' + configs['dataset_conf']['data_style_conf'] = {} + configs['dataset_conf']['data_style_conf']['add_bos'] = True + configs['dataset_conf']['data_style_conf']['add_eos'] = True + configs['dataset_conf']['data_style_conf']['template'] = template + configs['dataset_conf']['shift'] = True + + configs['grad_clip'] = 5 + configs['accum_grad'] = 4 + configs['max_epoch'] = 100 + configs['log_interval'] = 100 + configs['save_interval'] = 3000 + + configs['optim'] = "adam" + configs['optim_conf'] = {} + configs['optim_conf']['lr'] = 0.0005 + configs['scheduler'] = "warmuplr" + configs['scheduler_conf'] = {} + configs['scheduler_conf']['warmup_steps'] = 12000 + return configs + + +def wenet_decoderonly_conf(config: Config): + configs = {} + configs['decoder'] = 'decoder_only' + configs['decoder_conf'] = config.to_wenet_config() + configs['decoder_conf']['dropout_rate'] = 0.0 + configs['decoder_conf']['attention_dropout_rate'] = 0.0 + configs['decoder_conf']['positional_dropout_rate'] = 0.0 + configs['decoder_conf']['gradient_checkpointing'] = True + configs['decoder_conf']['normalize_before'] = True + configs['decoder_conf']['use_sdpa'] = True + return configs + + +def wenet_model_conf(config: Config, tie_word_embedding: bool = True): + configs = {} + configs['output_dim'] = config.vocab_size + configs['model'] = 
"causal_lm" + configs['model_conf'] = {} + configs['model_conf']['linear_bias'] = False + configs['model_conf']['tie_word_embedding'] = tie_word_embedding + configs['model_conf']['lsm_weight'] = 0.1 + configs['model_conf']['length_normalized_loss'] = False + return configs + + +def convert_to_wenet_yaml(config: Config, + wenet_yaml_path: str, + tokenizer_path, + template: str = 'gemma', + tie_word_embedding: bool = True, + special_tokens: Optional[Dict] = None): + configs = {} + configs.update( + wenet_llm_tokenizer_conf(config, tokenizer_path, special_tokens)) + configs.update(wenet_decoderonly_conf(config)) + configs.update( + wenet_model_conf(config, tie_word_embedding=tie_word_embedding)) + configs.update(wenet_llm_dataset_and_train_conf(config, template=template)) + + with open(wenet_yaml_path, '+w') as f: + f.write(yaml.dump(configs)) + f.flush() + + print(configs) + + +def gemma_config_for_7b() -> Config: + return Config(rope_theta=10000.0, gelu_approximate='tanh') + + +def gemma_config_for_2b() -> Config: + return Config(num_hidden_layers=18, + num_attention_heads=8, + num_key_value_heads=1, + hidden_size=2048, + intermediate_size=16384, + rope_theta=10000.0, + gelu_approximate='tanh') + + +def llama3_config_for_8b() -> Config: + return Config(vocab_size=128256, + num_hidden_layers=32, + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=128, + intermediate_size=14336, + rms_norm_eps=1e-5, + rope_theta=500000.0, + activation_type='swish', + rms_norm_offset=False, + rope_style='llama', + scale_embed=False) + + +def llama3_config_for_70b() -> Config: + return Config(vocab_size=128256, + num_hidden_layers=80, + hidden_size=8192, + head_dim=128, + num_attention_heads=64, + num_key_value_heads=8, + intermediate_size=28672, + rms_norm_eps=1e-5, + rope_theta=500000.0, + activation_type='swish', + rms_norm_offset=False, + rope_style='llama', + scale_embed=False) diff --git a/wenet/LLM/script/convert_main.py b/wenet/LLM/script/convert_main.py new file mode 100644 index 0000000000..30bca315af --- /dev/null +++ b/wenet/LLM/script/convert_main.py @@ -0,0 +1,86 @@ +import argparse + +import os + +import torch + +from wenet.LLM.script.config import (convert_to_wenet_yaml, + gemma_config_for_2b, gemma_config_for_7b, + llama3_config_for_70b, + llama3_config_for_8b) +from wenet.LLM.script.gemma_config import (convert_to_wenet_state_dict as + gemma_convert_ckpt_fn, + gemma_special_tokens) +from wenet.LLM.script.llama3_config import (convert_to_wenet_state_dict as + llama3_convert_ckpt_fn, + llama3_special_tokens) + + +def get_args(): + parser = argparse.ArgumentParser(description='load and convert llm ckpt') + parser.add_argument('--ckpt', + required=True, + help='llama3: https://llama.meta.com/llama-downloads/ \ + \ngemma: https://www.kaggle.com/models/google/gemma/frameworks/pyTorch' + ) + parser.add_argument('--model_size', type=str, required=True) + parser.add_argument('--model_name', type=str, required=True) + parser.add_argument('--output_dir', + default='.', + help='output file in wenet\'s style') + args = parser.parse_args() + return args + + +MODEL = { + "gemma": { + "2b": (gemma_config_for_2b(), 'google/gemma-2b'), + "7b": (gemma_config_for_7b(), 'google/gemma-7b'), + "ckpt_fn": gemma_convert_ckpt_fn, + 'tie_word_embeding': True, + 'special_tokens_fn': gemma_special_tokens, + }, + "llama3": { + "8b": (llama3_config_for_8b(), 'meta-llama/Meta-Llama-3-8B'), + "70b": (llama3_config_for_70b(), 'meta-llama/Meta-Llama-3-70B'), + "ckpt_fn": llama3_convert_ckpt_fn, + 
'tie_word_embeding': False, + 'special_tokens_fn': llama3_special_tokens, + }, +} + + +def main(): + args = get_args() + args.jit = False + model_size = args.model_size + model_name = args.model_name + assert model_name in MODEL.keys() + all(model_size in size.keys() for size in MODEL.values()) + config = MODEL[model_name][model_size][0] + args.tokenizer = MODEL[model_name][model_size][1] + + os.makedirs(args.output_dir, exist_ok=True) + + checkpoint = torch.load(args.ckpt, map_location="cpu") + if model_name == 'gemma': + checkpoint = checkpoint["model_state_dict"] + wenet_ckpt_path = os.path.join(args.output_dir, + 'wenet_' + os.path.basename(args.ckpt)) + wenet_ckpt_path = os.path.splitext(wenet_ckpt_path)[0] + ".pt" + convert_fn = MODEL[model_name]['ckpt_fn'] + convert_fn(checkpoint, wenet_ckpt_path, config) + + wenet_yaml_path = os.path.join(args.output_dir, 'train.yaml') + convert_to_wenet_yaml( + config, + wenet_yaml_path, + args.tokenizer, + template=model_name, + tie_word_embedding=MODEL[model_name]['tie_word_embeding'], + special_tokens=MODEL[model_name]['special_tokens_fn'](args.tokenizer, + config)) + + +if __name__ == '__main__': + main() diff --git a/wenet/LLM/script/gemma_config.py b/wenet/LLM/script/gemma_config.py new file mode 100644 index 0000000000..0ff7362747 --- /dev/null +++ b/wenet/LLM/script/gemma_config.py @@ -0,0 +1,83 @@ +import torch + +from wenet.LLM.script.config import Config +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def convert_to_wenet_state_dict(gemma_state_dict, wenet_state_dict_path, + config: Config): + + print("==============start CKPT Conversion =========================") + wenet_state_dict = {} + for name in gemma_state_dict.keys(): + old_name = name + # embed + name = name.replace('embedder.weight', 'embed.weight') + + # layers to decoders + name = name.replace('model.layers', 'decoder.decoders') + + if 'self_attn.qkv_proj' in name: + # att weight + i_layer = name.split('.')[2] + layer_prefix = 'decoder.decoders.' 
+ i_layer + linear_q_name = layer_prefix + '.self_attn.linear_q.weight' + linear_k_name = layer_prefix + '.self_attn.linear_k.weight' + linear_v_name = layer_prefix + '.self_attn.linear_v.weight' + + start = 0 + offset = config.num_attention_heads * config.head_dim + linear_q_value = gemma_state_dict[old_name][start:offset, :] + start = offset + offset = offset + config.head_dim * config.num_key_value_heads + linear_k_value = gemma_state_dict[old_name][start:offset, :] + start = offset + linear_v_value = gemma_state_dict[old_name][start:, :] + wenet_state_dict[linear_q_name] = linear_q_value + wenet_state_dict[linear_k_name] = linear_k_value + wenet_state_dict[linear_v_name] = linear_v_value + elif name == 'freqs_cis': + # rope position embeding + name = 'decoder.pos_enc.pe' + pe = torch.view_as_real(gemma_state_dict[old_name].unsqueeze(0)) + wenet_state_dict[name] = pe + else: + # att out dim + name = name.replace('self_attn.o_proj', 'self_attn.linear_out') + + # mlp + name = name.replace('mlp.gate_proj', 'feed_forward.gate') + name = name.replace('mlp.up_proj', 'feed_forward.w_1') + name = name.replace('mlp.down_proj', 'feed_forward.w_2') + + # pre ln (rms norm) + name = name.replace('input_layernorm', 'norm1') + # before mlp ln: (rms norm) + name = name.replace('post_attention_layernorm', 'norm2') + # final norm + name = name.replace('model.norm.weight', + 'decoder.final_norm.weight') + + wenet_state_dict[name] = gemma_state_dict[old_name] + # NOTE(Mddct): tie weight + wenet_state_dict['out.weight'] = wenet_state_dict['embed.weight'] + print("Saving {} ckpt to {}...".format(config.dtype, + wenet_state_dict_path)) + torch.save(wenet_state_dict, wenet_state_dict_path) + print( + "DONE\n===================- End CKPT Conversion ====================\n" + ) + + +def gemma_special_tokens(tokenizer_path, config: Config): + tokenizer = HuggingFaceTokenizer(tokenizer_path) + assert config.vocab_size == tokenizer.vocab_size() + special_tokens = {} + bos = tokenizer.tokens2ids([""])[0] + eos = tokenizer.tokens2ids([""])[0] + unk = tokenizer.tokens2ids([""])[0] + special_tokens = {} + special_tokens[''] = bos + special_tokens[''] = eos + special_tokens[''] = unk + return special_tokens diff --git a/wenet/LLM/script/llama3_config.py b/wenet/LLM/script/llama3_config.py new file mode 100644 index 0000000000..074f14a147 --- /dev/null +++ b/wenet/LLM/script/llama3_config.py @@ -0,0 +1,74 @@ +from typing import Dict +import torch +from wenet.LLM.script.config import Config + +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def llama3_special_tokens(tokenizer_path, config: Config) -> Dict: + tokenizer = HuggingFaceTokenizer(tokenizer_path) + assert config.vocab_size == tokenizer.vocab_size() + # "<|reserved_special_token_0|>", + # "<|reserved_special_token_1|>", + # "<|reserved_special_token_2|>", + # "<|reserved_special_token_3|>", + shi = tokenizer.tokens2ids(["<|start_header_id|>"])[0] + ehi = tokenizer.tokens2ids(["<|end_header_id|>"])[0] + bos = tokenizer.tokens2ids(["<|begin_of_text|>"])[0] + eos = tokenizer.tokens2ids(["<|end_of_text|>"])[0] + eoti = tokenizer.tokens2ids(["<|eot_id|>"])[0] + special_tokens = {} + special_tokens['<|begin_of_text|>'] = bos + special_tokens['<|end_of_text|>'] = eos + special_tokens['<|eot_id|>'] = eoti + special_tokens['<|start_header_id|>'] = shi + special_tokens['<|end_header_id|>'] = ehi + return special_tokens + + +def convert_to_wenet_state_dict(Llama3_state_dict, wenet_state_dict_path, + config: Config): + + wenet_state_dict = {} + + 
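The Gemma branch above splits the fused `self_attn.qkv_proj` weight into separate q/k/v projections by head counts; a self-contained sketch of that slicing with made-up small sizes (purely illustrative, not the real checkpoint shapes):

```python
import torch

num_attention_heads, num_key_value_heads, head_dim, hidden = 8, 1, 32, 256
# Fused projection: rows are laid out as [q heads | k heads | v heads].
qkv = torch.randn((num_attention_heads + 2 * num_key_value_heads) * head_dim, hidden)

q_rows = num_attention_heads * head_dim
kv_rows = num_key_value_heads * head_dim
linear_q = qkv[:q_rows, :]                      # (256, 256)
linear_k = qkv[q_rows:q_rows + kv_rows, :]      # (32, 256)
linear_v = qkv[q_rows + kv_rows:, :]            # (32, 256)
assert linear_q.shape[0] + linear_k.shape[0] + linear_v.shape[0] == qkv.shape[0]
```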
print("==============start CKPT Conversion =========================") + conformer_state_dict = Llama3_state_dict + wenet_state_dict = {} + for name in conformer_state_dict.keys(): + old_name = name + # embed + name = name.replace('tok_embeddings.weight', 'embed.weight') + # output + name = name.replace('output.weight', 'out.weight') + # layers to decoders + name = name.replace('layers', 'decoder.decoders') + if 'attention' in name: + # pre ln (rms norm) + name = name.replace('attention_norm.weight', 'norm1.weight') + # att weight + name = name.replace('.attention.wq.weight', + '.self_attn.linear_q.weight') + name = name.replace('.attention.wk.weight', + '.self_attn.linear_k.weight') + name = name.replace('.attention.wv.weight', + '.self_attn.linear_v.weight') + # att out dim + name = name.replace('attention.wo', 'self_attn.linear_out') + else: + # mlp + name = name.replace('feed_forward.w1', 'feed_forward.gate') + name = name.replace('feed_forward.w3', 'feed_forward.w_1') + name = name.replace('feed_forward.w2', 'feed_forward.w_2') + + # before mlp ln: (rms norm) + name = name.replace('ffn_norm', 'norm2') + wenet_state_dict[name] = conformer_state_dict[old_name] + # final norm weight + wenet_state_dict['decoder.final_norm.weight'] = conformer_state_dict[ + 'norm.weight'] + print("Saving {} ckpt to {}...".format(config.dtype, + wenet_state_dict_path)) + torch.save(wenet_state_dict, wenet_state_dict_path) + print( + "DONE\n===================- End CKPT Conversion ====================\n" + ) diff --git a/wenet/LLM/template.py b/wenet/LLM/template.py new file mode 100644 index 0000000000..b0e0a95418 --- /dev/null +++ b/wenet/LLM/template.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class Template: + # one turn :{system_format}{user_format}{assistant_format} + # multi turns: + # {system_format}{user_format}{assistant_format}{user_format}{assistant_format}... 
+ system: Optional[str] + user: str + assistant: str + + bos: str + eos: str + + +gemma = Template( + '', + 'user\n{content}\nmodel\n', + '{content}\n', + '', + '', +) + +llama3 = Template( + '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>', + '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n', + '{content}<|eot_id|>', + '<|begin_of_text|>', + '<|end_of_text|>', +) +WENET_LLM_Template = { + "gemma": gemma, + 'llama3': llama3, +} diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index f8a11b9229..3779b74eca 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -171,6 +171,11 @@ def get_args(): default=0.0, help='''The higher the score, the greater the degree of bias using decoding-graph for biasing''') + + parser.add_argument('--use_lora', + type=bool, + default=False, + help='''Whether to use lora for biasing''') args = parser.parse_args() print(args) return args diff --git a/wenet/bin/train.py b/wenet/bin/train.py index f772ff85fa..60d6374f4e 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -24,29 +24,33 @@ import torch.distributed as dist from torch.distributed.elastic.multiprocessing.errors import record +from wenet.utils.common import lrs_to_str from wenet.utils.executor import Executor from wenet.utils.config import override_config from wenet.utils.init_model import init_model from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.train_utils import ( - add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args, - add_trace_args, init_distributed, init_dataset_and_dataloader, - check_modify_and_save_config, init_optimizer_and_scheduler, - trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model, - log_per_epoch) + add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args, + add_deepspeed_args, add_trace_args, init_distributed, + init_dataset_and_dataloader, check_modify_and_save_config, + init_optimizer_and_scheduler, init_scaler, trace_and_print_model, + wrap_cuda_model, init_summarywriter, save_model, log_per_epoch, + add_lora_args) def get_args(): parser = argparse.ArgumentParser(description='training your network') parser.add_argument('--train_engine', default='torch_ddp', - choices=['torch_ddp', 'deepspeed'], + choices=['torch_ddp', 'torch_fsdp', 'deepspeed'], help='Engine for paralleled training') parser = add_model_args(parser) parser = add_dataset_args(parser) parser = add_ddp_args(parser) + parser = add_lora_args(parser) parser = add_deepspeed_args(parser) + parser = add_fsdp_args(parser) parser = add_trace_args(parser) args = parser.parse_args() if args.train_engine == "deepspeed": @@ -96,7 +100,7 @@ def main(): writer = init_summarywriter(args) # Dispatch model from cpu to gpu - model, device = wrap_cuda_model(args, model) + model, device = wrap_cuda_model(args, model, configs) # Get optimizer & scheduler model, optimizer, scheduler = init_optimizer_and_scheduler( @@ -114,13 +118,10 @@ def main(): # Get executor tag = configs["init_infos"].get("tag", "init") - executor = Executor(global_step=configs["init_infos"].get('step', -1) + - int("step_" in tag)) + executor = Executor(global_step=configs["init_infos"].get('step', -1)) # Init scaler, used for pytorch amp mixed precision training - scaler = None - if args.use_amp: - scaler = torch.cuda.amp.GradScaler() + scaler = init_scaler(args) # Start training loop start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag) @@ -133,9 +134,9 @@ def main(): for 
epoch in range(start_epoch, end_epoch): configs['epoch'] = epoch - lr = optimizer.param_groups[0]['lr'] - logging.info('Epoch {} TRAIN info lr {} rank {}'.format( - epoch, lr, rank)) + lrs = [group['lr'] for group in optimizer.param_groups] + logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format( + epoch, executor.step, lrs_to_str(lrs), rank)) dist.barrier( ) # NOTE(xcsong): Ensure all ranks start Train at the same time. @@ -149,19 +150,16 @@ def main(): dist.barrier( ) # NOTE(xcsong): Ensure all ranks start CV at the same time. loss_dict = executor.cv(model, cv_data_loader, configs) - - lr = optimizer.param_groups[0]['lr'] - logging.info('Epoch {} CV info lr {} cv_loss {} rank {} acc {}'.format( - epoch, lr, loss_dict["loss"], rank, loss_dict["acc"])) info_dict = { 'epoch': epoch, - 'lr': lr, + 'lrs': [group['lr'] for group in optimizer.param_groups], 'step': executor.step, 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), 'tag': "epoch_{}".format(epoch), 'loss_dict': loss_dict, **configs } + # epoch cv: tensorboard && log log_per_epoch(writer, info_dict=info_dict) save_model(model, info_dict=info_dict) @@ -173,6 +171,7 @@ def main(): final_model_path) else None os.symlink('{}.pt'.format(final_epoch), final_model_path) writer.close() + dist.destroy_process_group() if __name__ == '__main__': diff --git a/wenet/branchformer/encoder.py b/wenet/branchformer/encoder.py index 7d00b2a70b..2feda978e1 100644 --- a/wenet/branchformer/encoder.py +++ b/wenet/branchformer/encoder.py @@ -16,21 +16,17 @@ """Encoder definition.""" import torch -import torch.nn as nn -from typing import List, Optional, Tuple, Union + +from typing import List, Optional, Union from wenet.branchformer.encoder_layer import BranchformerEncoderLayer from wenet.branchformer.cgmlp import ConvolutionalGatingMLP -from wenet.utils.mask import make_pad_mask -from wenet.utils.mask import add_optional_chunk_mask +from wenet.transformer.encoder import BaseEncoder from wenet.utils.class_utils import ( - WENET_ATTENTION_CLASSES, - WENET_EMB_CLASSES, - WENET_SUBSAMPLE_CLASSES, -) + WENET_ATTENTION_CLASSES, ) -class BranchformerEncoder(nn.Module): +class BranchformerEncoder(BaseEncoder): """Branchformer encoder module.""" def __init__( @@ -39,7 +35,7 @@ def __init__( output_size: int = 256, use_attn: bool = True, attention_heads: int = 4, - attention_layer_type: str = "rel_selfattn", + selfattention_layer_type: str = "rel_selfattn", pos_enc_layer_type: str = "rel_pos", use_cgmlp: bool = True, cgmlp_linear_units: int = 2048, @@ -53,30 +49,41 @@ def __init__( dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, - input_layer: Optional[str] = "conv2d", - padding_idx: int = -1, + input_layer: str = "conv2d", stochastic_depth_rate: Union[float, List[float]] = 0.0, static_chunk_size: int = 0, use_dynamic_chunk: bool = False, global_cmvn: torch.nn.Module = None, use_dynamic_left_chunk: bool = False, causal: bool = False, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, ): - super().__init__() - self._output_size = output_size - - self.embed = WENET_SUBSAMPLE_CLASSES[input_layer]( - input_size, - output_size, - dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type](output_size, - positional_dropout_rate), - ) + super().__init__(input_size, 
output_size, attention_heads, + cgmlp_linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, True, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing, + use_sdpa, layer_norm_type, norm_eps) encoder_selfattn_layer_args = ( attention_heads, output_size, attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, ) cgmlp_layer = ConvolutionalGatingMLP @@ -87,6 +94,7 @@ def __init__( dropout_rate, use_linear_after_conv, gate_activation, + causal, ) if isinstance(stochastic_depth_rate, float): @@ -110,221 +118,60 @@ def __init__( f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) " f"should be equal to num_blocks ({num_blocks})") - self.encoders = torch.nn.ModuleList([ - BranchformerEncoderLayer( - output_size, WENET_ATTENTION_CLASSES[attention_layer_type]( - *encoder_selfattn_layer_args) if use_attn else None, - cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, - dropout_rate, merge_method, cgmlp_weight[lnum], - attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum]) - for lnum in range(num_blocks) - ]) - self.after_norm = nn.LayerNorm(output_size) - self.static_chunk_size = static_chunk_size - self.global_cmvn = global_cmvn - self.use_dynamic_chunk = use_dynamic_chunk - self.use_dynamic_left_chunk = use_dynamic_left_chunk - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs: torch.Tensor, - ilens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - xs (torch.Tensor): Input tensor (B, T, D). - ilens (torch.Tensor): Input length (#batch). - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. 
- >=0: use num_decoding_left_chunks - <0: use all left chunks - - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - """ - - T = xs.size(1) - masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks - - def forward_chunk( - self, - xs: torch.Tensor, - offset: int, - required_cache_size: int, - att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ Forward just one chunk - - Args: - xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), - where `time == (chunk_size - 1) * subsample_rate + \ - subsample.right_context + 1` - offset (int): current offset in encoder output time stamp - required_cache_size (int): cache size required for next chunk - compuation - >=0: actual cache size - <0: means all history cache is required - att_cache (torch.Tensor): cache tensor for KEY & VALUE in - transformer/conformer attention, with shape - (elayers, head, cache_t1, d_k * 2), where - `head * d_k == hidden-dim` and - `cache_t1 == chunk_size * num_decoding_left_chunks`. - cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, - (elayers, b=1, hidden-dim, cache_t2), where - `cache_t2 == cnn.lorder - 1` - - Returns: - torch.Tensor: output of current input xs, - with shape (b=1, chunk_size, hidden-dim). - torch.Tensor: new attention cache required for next chunk, with - dynamic shape (elayers, head, ?, d_k * 2) - depending on required_cache_size. - torch.Tensor: new conformer cnn cache required for next chunk, with - same shape as the original cnn_cache. 
- - """ - assert xs.size(0) == 1 - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) - # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) - elayers, cache_t1 = att_cache.size(0), att_cache.size(2) - chunk_size = xs.size(1) - attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding(offset=offset - cache_t1, - size=attention_key_size) - if required_cache_size < 0: - next_cache_start = 0 - elif required_cache_size == 0: - next_cache_start = attention_key_size - else: - next_cache_start = max(attention_key_size - required_cache_size, 0) - r_att_cache = [] - r_cnn_cache = [] - for i, layer in enumerate(self.encoders): - # NOTE(xcsong): Before layer.forward - # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), - # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) - xs, _, new_att_cache, new_cnn_cache = layer( - xs, - att_mask, - pos_emb, - att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) - # NOTE(xcsong): After layer.forward - # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), - # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) - r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) - r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) - - xs = self.after_norm(xs) - - # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), - # ? may be larger than cache_t1, it depends on required_cache_size - r_att_cache = torch.cat(r_att_cache, dim=0) - # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) - r_cnn_cache = torch.cat(r_cnn_cache, dim=0) - - return (xs, r_att_cache, r_cnn_cache) - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - decoding_chunk_size: int, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Forward input chunk by chunk with chunk_size like a streaming - fashion - - Here we should pay special attention to computation cache in the - streaming style forward chunk by chunk. Three things should be taken - into account for computation in the current network: - 1. transformer/conformer encoder layers output cache - 2. convolution in conformer - 3. convolution in subsampling - - However, we don't implement subsampling cache for: - 1. We can control subsampling module to output the right result by - overlapping input instead of cache left context, even though it - wastes some computation, but subsampling only takes a very - small fraction of computation in the whole model. - 2. Typically, there are several covolution layers with subsampling - in subsampling module, it is tricky and complicated to do cache - with different convolution layers with different subsampling - rate. - 3. Currently, nn.Sequential is used to stack all the convolution - layers in subsampling, we need to rewrite it to make it work - with cache, which is not prefered. 
- Args: - xs (torch.Tensor): (1, max_len, dim) - chunk_size (int): decoding chunk size - """ - assert decoding_chunk_size > 0 - # The model is trained by static or dynamic chunk - assert self.static_chunk_size > 0 or self.use_dynamic_chunk - subsampling = self.embed.subsampling_rate - context = self.embed.right_context + 1 # Add current frame - stride = subsampling * decoding_chunk_size - decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.size(1) - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - outputs = [] - offset = 0 - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - - # Feed forward overlap input step by step - for cur in range(0, num_frames - context + 1, stride): - end = min(cur + decoding_window, num_frames) - chunk_xs = xs[:, cur:end, :] - (y, att_cache, - cnn_cache) = self.forward_chunk(chunk_xs, offset, - required_cache_size, att_cache, - cnn_cache) - outputs.append(y) - offset += y.size(1) - ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), - device=ys.device, - dtype=torch.bool) - return ys, masks + self.encoders = LayerDropModuleList( + p=stochastic_depth_rate, + modules=[ + BranchformerEncoderLayer( + output_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args) if use_attn else None, + cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, + dropout_rate, + merge_method, + cgmlp_weight[lnum], + attn_branch_drop_rate[lnum], + stochastic_depth_rate[lnum], + ) for lnum in range(num_blocks) + ]) + + +# modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa +class LayerDropModuleList(torch.nn.ModuleList): + """ + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + We refresh the choice of which layers to drop every time we iterate + over the LayerDropModuleList instance. During evaluation we always + iterate over all layers. + + Usage:: + + layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) + for layer in layers: # this might iterate over layers 1 and 3 + x = layer(x) + for layer in layers: # this might iterate over all layers + x = layer(x) + for layer in layers: # this might not iterate over any layers + x = layer(x) + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + + Limitations: + 1 can work with ddp when layer's gradient checkpoint disabled + 2 can't work with ddp when layer's gradient checkpoint enables + 3 can work with fsdp + 4 can work with deepspeed + """ + + def __init__(self, p: List[float], modules=None): + super().__init__(modules) + assert len(p) == len(self) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p[i]): + yield m diff --git a/wenet/branchformer/encoder_layer.py b/wenet/branchformer/encoder_layer.py index 9654a24059..a48feefbd1 100644 --- a/wenet/branchformer/encoder_layer.py +++ b/wenet/branchformer/encoder_layer.py @@ -19,6 +19,8 @@ import torch.nn as nn from typing import Optional, Tuple +from wenet.transformer.attention import T_CACHE + class BranchformerEncoderLayer(torch.nn.Module): """Branchformer encoder layer module. 
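Note on the LayerDropModuleList introduced above: it re-samples which encoder layers to skip on every iteration during training and always runs every layer in eval mode. Below is a minimal usage sketch, assuming the patched wenet/branchformer/encoder.py is importable; the toy nn.Linear "layers" and tensor shapes are illustrative only, standing in for the BranchformerEncoderLayer instances and the per-layer stochastic_depth_rate list used in the real encoder.

```python
import torch
from wenet.branchformer.encoder import LayerDropModuleList

# Three toy "layers"; in the encoder these would be BranchformerEncoderLayer
# instances and p would be the per-layer stochastic_depth_rate list.
layers = LayerDropModuleList(
    p=[0.0, 0.1, 0.2],
    modules=[torch.nn.Linear(8, 8) for _ in range(3)])

x = torch.randn(2, 8)

layers.train()           # training: layer i survives with probability 1 - p[i]
for layer in layers:     # the drop decision is re-sampled on every iteration
    x = layer(x)

layers.eval()            # evaluation: every layer always runs
for layer in layers:
    x = layer(x)
```

As the Limitations note above says, this per-iteration skipping is the reason the combination with per-layer gradient checkpointing under DDP is not supported.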
@@ -106,48 +108,17 @@ def __init__( else: self.merge_proj = torch.nn.Identity() - def forward( + def _forward( self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute encoded features. - - Args: - x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time, time). - pos_emb (torch.Tensor): positional encoding, must not be None - for BranchformerEncoderLayer. - mask_pad (torch.Tensor): batch padding mask used for conv module. - (#batch, 1,time), (0, 0, 0) means fake mask. - att_cache (torch.Tensor): Cache tensor of the KEY & VALUE - (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. - cnn_cache (torch.Tensor): Convolution cache in cgmlp layer - (#batch=1, size, cache_t2) - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time, time. - torch.Tensor: att_cache tensor, - (#batch=1, head, cache_t1 + time, d_k * 2). - torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). - """ - - stoch_layer_coeff = 1.0 - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - return x, mask, att_cache, cnn_cache - + stoch_layer_coeff: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: # Two branches x1 = x x2 = x @@ -232,3 +203,43 @@ def forward( x = self.norm_final(x) return x, mask, new_att_cache, new_cnn_cache + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: + """Compute encoded features. + + Args: + x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time, time). + pos_emb (torch.Tensor): positional encoding, must not be None + for BranchformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in cgmlp layer + (#batch=1, size, cache_t2) + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time. + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + stoch_layer_coeff = 1.0 + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
+ if self.training: + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache, + stoch_layer_coeff) diff --git a/wenet/ctl_model/asr_model_ctl.py b/wenet/ctl_model/asr_model_ctl.py index c5457e590a..6e9bc810a7 100644 --- a/wenet/ctl_model/asr_model_ctl.py +++ b/wenet/ctl_model/asr_model_ctl.py @@ -67,7 +67,7 @@ def __init__( self.ctl_weight = ctl_weight self.logit_temp = logit_temp - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward( self, batch: dict, diff --git a/wenet/ctl_model/encoder.py b/wenet/ctl_model/encoder.py index 6b71d0cf83..9aa18b7048 100644 --- a/wenet/ctl_model/encoder.py +++ b/wenet/ctl_model/encoder.py @@ -15,12 +15,11 @@ # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) """Encoder definition.""" -from typing import Tuple +from typing import Optional, Tuple import torch from wenet.utils.mask import make_pad_mask -from wenet.utils.mask import add_optional_chunk_mask from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder @@ -44,6 +43,21 @@ def __init__( use_dynamic_chunk: bool = False, global_cmvn: torch.nn.Module = None, use_dynamic_left_chunk: bool = False, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + selfattention_layer_type: str = "selfattn", + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): """ Construct DualTransformerEncoder Support both the full context mode and the streaming mode separately @@ -53,56 +67,11 @@ def __init__( positional_dropout_rate, attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, static_chunk_size, use_dynamic_chunk, global_cmvn, - use_dynamic_left_chunk) - - def forward( - self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Embed positions in tensor. - - Args: - xs: padded input tensor (B, T, D) - xs_lens: input length (B) - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. 
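For reference, a quick numeric check of the stochastic-depth rescaling used in the rewritten BranchformerEncoderLayer.forward above (wenet/branchformer/encoder_layer.py): in this patch the drop decision itself lives in LayerDropModuleList, while the layer only applies the 1 / (1 - p) scale during training, so a surviving layer's expected residual contribution matches the always-on behaviour at eval time. The drop probability and the constant stand-in for f(x) below are illustrative.

```python
import torch

torch.manual_seed(0)

p = 0.3                                    # illustrative drop probability
keep = (torch.rand(100_000) > p).float()   # 1 when the layer runs, 0 when dropped
scale = 1.0 / (1.0 - p)                    # the stoch_layer_coeff applied when it runs
f_x = 2.0                                  # stand-in for the layer's residual update f(x)

print((keep * scale * f_x).mean())         # ~2.0, i.e. E[keep * f(x) / (1 - p)] = f(x)
```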
- >=0: use num_decoding_left_chunks - <0: use all left chunks - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - """ - T = xs.size(1) - masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, - masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks, - enable_full_context=False) - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - if self.normalize_before: - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks + use_dynamic_left_chunk, query_bias, key_bias, + value_bias, activation_type, gradient_checkpointing, + use_sdpa, layer_norm_type, norm_eps, n_kv_head, + head_dim, selfattention_layer_type, mlp_type, + mlp_bias, n_expert, n_expert_activated) def forward_full( self, @@ -152,68 +121,36 @@ def __init__( cnn_module_kernel: int = 15, causal: bool = False, cnn_module_norm: str = "batch_norm", + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + conv_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): """ Construct DualConformerEncoder Support both the full context mode and the streaming mode separately """ - super().__init__(input_size, output_size, attention_heads, - linear_units, num_blocks, dropout_rate, - positional_dropout_rate, attention_dropout_rate, - input_layer, pos_enc_layer_type, normalize_before, - static_chunk_size, use_dynamic_chunk, global_cmvn, - use_dynamic_left_chunk, positionwise_conv_kernel_size, - macaron_style, selfattention_layer_type, - activation_type, use_cnn_module, cnn_module_kernel, - causal, cnn_module_norm) - - def forward( - self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Embed positions in tensor. - - Args: - xs: padded input tensor (B, T, D) - xs_lens: input length (B) - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. 
- >=0: use num_decoding_left_chunks - <0: use all left chunks - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - """ - T = xs.size(1) - masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, - masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks, - enable_full_context=False) - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - if self.normalize_before: - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks + super().__init__( + input_size, output_size, attention_heads, linear_units, num_blocks, + dropout_rate, positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, positionwise_conv_kernel_size, + macaron_style, selfattention_layer_type, activation_type, + use_cnn_module, cnn_module_kernel, causal, cnn_module_norm, + query_bias, key_bias, value_bias, conv_bias, + gradient_checkpointing, use_sdpa, layer_norm_type, norm_eps, + n_kv_head, head_dim, mlp_type, mlp_bias, n_expert, + n_expert_activated) def forward_full( self, diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py index 6d89ab5522..0d2fdca38e 100644 --- a/wenet/dataset/datapipes.py +++ b/wenet/dataset/datapipes.py @@ -184,21 +184,24 @@ def __iter__(self): @functional_datapipe("dynamic_batch") class DynamicBatchDataPipe(IterDataPipe): - def __init__(self, dataset: IterDataPipe, window_class, - wrapper_class) -> None: + def __init__(self, dataset: IterDataPipe, window_class, wrapper_class, + elem_size_fn) -> None: _check_unpickable_fn(window_class) _check_unpickable_fn(wrapper_class) + _check_unpickable_fn(elem_size_fn) super().__init__() self.dp = dataset assert window_class is not None assert wrapper_class is not None + self.elem_size_fn = elem_size_fn self.window_class = window_class self._buffer = [] self._wrappr_class = wrapper_class def __iter__(self): for elem in self.dp: - if not self.window_class(elem, len(self._buffer)): + if not self.window_class(self.elem_size_fn(elem), len( + self._buffer)): self._buffer.append(elem) else: if len(self._buffer) > 0: diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 280d66169a..d88ae8ad0c 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -66,7 +66,8 @@ def Dataset(data_type, dataset = dataset.map_ignore_error(processor.decode_wav) singal_channel_conf = conf.get('singal_channel_conf', {}) - dataset = dataset.map(partial(processor.singal_channel, **singal_channel_conf)) + dataset = dataset.map( + partial(processor.singal_channel, **singal_channel_conf)) speaker_conf = conf.get('speaker_conf', None) if speaker_conf is not None: @@ -149,6 +150,7 @@ def Dataset(data_type, dataset = dataset.dynamic_batch( processor.DynamicBatchWindow(max_frames_in_batch), wrapper_class=processor.padding, + elem_size_fn=processor.feats_length_fn, ) return dataset diff --git 
a/wenet/dataset/llm_dataset.py b/wenet/dataset/llm_dataset.py new file mode 100644 index 0000000000..dd5f323d59 --- /dev/null +++ b/wenet/dataset/llm_dataset.py @@ -0,0 +1,116 @@ +from functools import partial +import sys +from wenet.LLM.template import WENET_LLM_Template +from wenet.dataset.datapipes import (WenetRawDatasetSource) +from wenet.dataset import (processor, llm_processor) +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def Dataset(data_type, + data_list_file, + tokenizer: BaseTokenizer, + conf=None, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. + + Args: + data_type(str): raw/shard + tokenizer (BaseTokenizer or None): tokenizer to tokenize + partition(bool): whether to do data partition in terms of rank + """ + assert conf is not None + assert data_type in ['raw', 'shard'] + # cycle dataset + cycle = conf.get('cycle', 1) + # stage1 shuffle: source + list_shuffle = conf.get('list_shuffle', True) + list_shuffle_size = sys.maxsize + if list_shuffle: + list_shuffle_conf = conf.get('list_shuffle_conf', {}) + list_shuffle_size = list_shuffle_conf.get('shuffle_size', + list_shuffle_size) + if data_type == 'raw': + dataset = WenetRawDatasetSource(data_list_file, + partition=partition, + shuffle=list_shuffle, + shuffle_size=list_shuffle_size, + cycle=cycle) + dataset = dataset.map(processor.parse_json) + + else: + raise NotImplementedError('only support jsonl for now') + + # TODO: DPO etc + data_style = conf.get('style', 'sft') + assert data_style in ['pretrain', 'sft'] + assert isinstance(tokenizer, HuggingFaceTokenizer) + style_conf = conf.get('data_style_conf', {}) + template = WENET_LLM_Template[style_conf.get('template', 'gemma')] + if data_style == 'sft': + dataset = dataset.map( + partial( + llm_processor.parse_sft, + tokenizer=tokenizer, + template=template, + add_bos=style_conf.get('add_bos', True), + add_eos=style_conf.get('add_eos', True), + )) + else: + dataset = dataset.map( + partial( + llm_processor.parse_pretrain, + tokenizer=tokenizer, + template=template, + add_bos=style_conf.get('add_bos', True), + add_eos=style_conf.get('add_eos', True), + )) + shuffle = conf.get('shuffle', True) + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size']) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = dataset.sort(buffer_size=sort_conf['sort_size'], + key_func=llm_processor.sort_by_input) + shift = conf.get('shift', True) + if shift: + dataset = dataset.map(llm_processor.shift) + + filter_conf = conf.get('filter_conf', {}) + dataset = dataset.filter(partial(llm_processor.filter, **filter_conf)) + + batch_conf = conf.get('batch_conf', {}) + batch_type = batch_conf.get('batch_type', 'static') + assert batch_type in ['static', 'bucket', 'dynamic'] + if batch_type == 'static': + assert 'batch_size' in batch_conf + batch_size = batch_conf.get('batch_size', 16) + dataset = dataset.batch( + batch_size, + wrapper_class=llm_processor.padding, + ) + elif batch_type == 'bucket': + assert 'bucket_boundaries' in batch_conf + assert 'bucket_batch_sizes' in batch_conf + dataset = dataset.bucket_by_sequence_length( + llm_processor.input_length_fn, + batch_conf['bucket_boundaries'], + batch_conf['bucket_batch_sizes'], + 
wrapper_class=llm_processor.padding, + ) + else: + max_tokens_in_batch = batch_conf.get('max_tokens_in_batch', 50000) + dataset = dataset.dynamic_batch( + processor.DynamicBatchWindow(max_tokens_in_batch), + wrapper_class=llm_processor.padding, + elem_size_fn=llm_processor.input_length_fn, + ) + + return dataset diff --git a/wenet/dataset/llm_processor.py b/wenet/dataset/llm_processor.py new file mode 100644 index 0000000000..f88bdcdc5e --- /dev/null +++ b/wenet/dataset/llm_processor.py @@ -0,0 +1,173 @@ +from typing import Dict, List + +import torch +from torch.nn.utils.rnn import pad_sequence +from wenet.LLM.template import Template +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer +from wenet.utils.common import IGNORE_ID + + +def parse_sft( + sample, + tokenizer: HuggingFaceTokenizer, + template: Template, + add_bos: bool = True, + add_eos: bool = True, +): + """Paser sft json line to tensor + + Args: + sample: + { + 'system': 'you are a helpful ...', + "conversation": [{ + 'human': '...', + 'assistant': '...' + }] + } + + Returns: + {input_ids, output_ids} + """ + chat_pattern = template + input_ids = [] + output_ids = [] + system_text = sample.get('system', '') + if chat_pattern.system is not None: + system_text = chat_pattern.system.format(content=system_text) + if add_bos: + system_text = template.bos + system_text + _, system_text_ids = tokenizer.tokenize(system_text) + input_ids += system_text_ids + output_ids += [IGNORE_ID] * len(system_text_ids) + conversations = sample['conversation'] + assert isinstance(conversations, List) + for conversation in conversations: + human = conversation['human'] + human = chat_pattern.user.format(content=human) + _, human_ids = tokenizer.tokenize(human) + input_ids += human_ids + output_ids += [IGNORE_ID] * len(human_ids) + if 'assistant' in conversation: + assistant = conversation['assistant'] + assistant = chat_pattern.assistant.format(content=assistant) + _, assistant_ids = tokenizer.tokenize(assistant) + input_ids += assistant_ids + output_ids += assistant_ids + + if add_eos: + eos_id = tokenizer.tokens2ids([template.eos]) + input_ids += eos_id + output_ids += eos_id + + assert len(input_ids) == len(output_ids) + return { + 'input_ids': torch.tensor(input_ids), + 'output_ids': torch.tensor(output_ids), + } + + +def parse_pretrain(sample, + tokenizer: HuggingFaceTokenizer, + template: Template, + add_bos: bool = True, + add_eos: bool = False): + """ Parse text from json line + + Args: + sample: str, str is a json line has txt + + Returns: + {input_ids, output_ids} + """ + assert 'text' in sample + text = sample['text'] + _, input_ids = tokenizer.tokenize(text) + if add_bos: + input_ids = [template.bos] + input_ids + if add_eos: + input_ids = input_ids + [template.eos] + + return { + 'input_ids': torch.tensor(input_ids), + 'output_ids': torch.tensor(input_ids), + } + + +def shift(sample): + input_ids = sample['input_ids'] + output_ids = sample['output_ids'] + + sample['input_ids'] = input_ids[:-1] + sample['output_ids'] = output_ids[1:] + return sample + + +def filter(sample, token_max_length: int = 8190, token_min_length=1): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + sample: {input_ids, output_ids} + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + + Returns: + bool: True to keep, False to filter + """ + assert 'input_ids' in sample + assert 'output_ids' in sample + assert isinstance(sample['input_ids'], torch.Tensor) + assert isinstance(sample['output_ids'], torch.Tensor) + if sample['input_ids'].size(0) < token_min_length: + return False + if sample['input_ids'].size(0) > token_max_length: + return False + return True + + +def sort_by_input(sample): + assert 'input_ids' in sample + assert isinstance(sample['input_ids'], torch.Tensor) + return sample['input_ids'].size(0) + + +def input_length_fn(sample) -> int: + assert 'input_ids' in sample + return sample['input_ids'].size(0) + + +def padding(data: List[Dict]): + """ Padding the data into training data + + Args: + data: List[{input_ids, output_ids} + + Returns: + Tuple(feats, labels) + """ + sample = data + feats_length = torch.tensor([x['input_ids'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['input_ids'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['input_ids'] for i in order] + sorted_labels = [sample[i]['output_ids'] for i in order] + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=IGNORE_ID) + + batch = { + 'feats': padded_feats, + "target": padding_labels, + "feats_lengths": feats_lengths, + "target_lengths": feats_lengths, + } + return batch diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 8f0008d437..77ddd4bd7f 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -141,6 +141,7 @@ def decode_wav(sample): sample['sample_rate'] = sample_rate return sample + def singal_channel(sample, channel=0): """ Choose a channel of sample. Inplace operation. 
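To make the new llm_processor conventions above (parse_sft, shift, padding) concrete, here is a toy walk-through of how the SFT labelling and shift() combine into next-token-prediction pairs. The token ids are made up and IGNORE_ID is written out as the usual -1 sentinel instead of being imported, so nothing here depends on a real HuggingFace tokenizer.

```python
import torch

IGNORE_ID = -1                    # stand-in for wenet.utils.common.IGNORE_ID

# Hypothetical token ids: prompt (system + user turns) and assistant answer.
prompt_ids = [5, 6, 7]
answer_ids = [8, 9]

# parse_sft masks every prompt position so only assistant tokens are supervised.
input_ids = torch.tensor(prompt_ids + answer_ids)
output_ids = torch.tensor([IGNORE_ID] * len(prompt_ids) + answer_ids)

# shift(): feed tokens < t, predict token t.
feats = input_ids[:-1]            # tensor([5, 6, 7, 8])
target = output_ids[1:]           # tensor([-1, -1, 8, 9])
print(feats.tolist(), target.tolist())
```

padding() then sorts such samples by length and pads the targets with IGNORE_ID, so the masked prompt positions never contribute to the loss, and input_length_fn feeds the same lengths to the dynamic_batch window as elem_size_fn.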
@@ -569,11 +570,8 @@ def __init__(self, max_frames_in_batch=12000): self.longest_frames = 0 self.max_frames_in_batch = max_frames_in_batch - def __call__(self, sample, buffer_size): - assert isinstance(sample, dict) - assert 'feat' in sample - assert isinstance(sample['feat'], torch.Tensor) - new_sample_frames = sample['feat'].size(0) + def __call__(self, elem_size, buffer_size): + new_sample_frames = elem_size self.longest_frames = max(self.longest_frames, new_sample_frames) frames_after_padding = self.longest_frames * (buffer_size + 1) if frames_after_padding > self.max_frames_in_batch: diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py index 2d4c6097e8..c298e473bb 100644 --- a/wenet/e_branchformer/encoder.py +++ b/wenet/e_branchformer/encoder.py @@ -17,23 +17,20 @@ """Encoder definition.""" import torch -import torch.nn as nn -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union +from wenet.branchformer.encoder import LayerDropModuleList from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.branchformer.cgmlp import ConvolutionalGatingMLP -from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward -from wenet.utils.mask import make_pad_mask -from wenet.utils.mask import add_optional_chunk_mask +from wenet.transformer.encoder import ConformerEncoder from wenet.utils.class_utils import ( - WENET_ATTENTION_CLASSES, - WENET_EMB_CLASSES, - WENET_SUBSAMPLE_CLASSES, WENET_ACTIVATION_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_MLP_CLASSES, ) -class EBranchformerEncoder(nn.Module): +class EBranchformerEncoder(ConformerEncoder): """E-Branchformer encoder module.""" def __init__( @@ -42,20 +39,18 @@ def __init__( output_size: int = 256, attention_heads: int = 4, linear_units: int = 2048, - attention_layer_type: str = "rel_selfattn", + selfattention_layer_type: str = "rel_selfattn", pos_enc_layer_type: str = "rel_pos", activation_type: str = "swish", cgmlp_linear_units: int = 2048, cgmlp_conv_kernel: int = 31, use_linear_after_conv: bool = False, gate_activation: str = "identity", - merge_method: str = "concat", num_blocks: int = 12, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, - input_layer: Optional[str] = "conv2d", - padding_idx: int = -1, + input_layer: str = "conv2d", stochastic_depth_rate: Union[float, List[float]] = 0.0, static_chunk_size: int = 0, use_dynamic_chunk: bool = False, @@ -65,23 +60,65 @@ def __init__( merge_conv_kernel: int = 3, use_ffn: bool = True, macaron_style: bool = True, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + conv_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): - super().__init__() - activation = WENET_ACTIVATION_CLASSES[activation_type]() - self._output_size = output_size - - self.embed = WENET_SUBSAMPLE_CLASSES[input_layer]( - input_size, - output_size, - dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type](output_size, - positional_dropout_rate), - ) + super().__init__(input_size, + output_size, + attention_heads, + linear_units, + num_blocks, + dropout_rate, + positional_dropout_rate, + attention_dropout_rate, + input_layer, + pos_enc_layer_type, + True, + 
static_chunk_size, + use_dynamic_chunk, + global_cmvn, + use_dynamic_left_chunk, + 1, + macaron_style, + selfattention_layer_type, + activation_type, + query_bias=query_bias, + key_bias=key_bias, + value_bias=value_bias, + conv_bias=conv_bias, + gradient_checkpointing=gradient_checkpointing, + use_sdpa=use_sdpa, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + n_kv_head=n_kv_head, + head_dim=head_dim, + mlp_type=mlp_type, + mlp_bias=mlp_bias, + n_expert=n_expert, + n_expert_activated=n_expert_activated) encoder_selfattn_layer_args = ( attention_heads, output_size, attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, ) cgmlp_layer = ConvolutionalGatingMLP @@ -90,12 +127,16 @@ def __init__( gate_activation, causal) # feed-forward module definition - positionwise_layer = PositionwiseFeedForward + mlp_class = WENET_MLP_CLASSES[mlp_type] + activation = WENET_ACTIVATION_CLASSES[activation_type]() positionwise_layer_args = ( output_size, linear_units, dropout_rate, activation, + mlp_bias, + n_expert, + n_expert_activated, ) if isinstance(stochastic_depth_rate, float): @@ -105,229 +146,20 @@ def __init__( f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " f"should be equal to num_blocks ({num_blocks})") - self.encoders = torch.nn.ModuleList([ - EBranchformerEncoderLayer( - output_size, - WENET_ATTENTION_CLASSES[attention_layer_type]( - *encoder_selfattn_layer_args), - cgmlp_layer(*cgmlp_layer_args), - positionwise_layer( - *positionwise_layer_args) if use_ffn else None, - positionwise_layer(*positionwise_layer_args) - if use_ffn and macaron_style else None, - dropout_rate, - merge_conv_kernel=merge_conv_kernel, - causal=causal, - stochastic_depth_rate=stochastic_depth_rate[lnum], - ) for lnum in range(num_blocks) - ]) - - self.after_norm = nn.LayerNorm(output_size) - self.static_chunk_size = static_chunk_size - self.global_cmvn = global_cmvn - self.use_dynamic_chunk = use_dynamic_chunk - self.use_dynamic_left_chunk = use_dynamic_left_chunk - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs: torch.Tensor, - ilens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - xs (torch.Tensor): Input tensor (B, T, D). - ilens (torch.Tensor): Input length (#batch). - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. 
- >=0: use num_decoding_left_chunks - <0: use all left chunks - - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - """ - - T = xs.size(1) - masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks - - def forward_chunk( - self, - xs: torch.Tensor, - offset: int, - required_cache_size: int, - att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ Forward just one chunk - - Args: - xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), - where `time == (chunk_size - 1) * subsample_rate + \ - subsample.right_context + 1` - offset (int): current offset in encoder output time stamp - required_cache_size (int): cache size required for next chunk - compuation - >=0: actual cache size - <0: means all history cache is required - att_cache (torch.Tensor): cache tensor for KEY & VALUE in - transformer/conformer attention, with shape - (elayers, head, cache_t1, d_k * 2), where - `head * d_k == hidden-dim` and - `cache_t1 == chunk_size * num_decoding_left_chunks`. - cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, - (elayers, b=1, hidden-dim, cache_t2), where - `cache_t2 == cnn.lorder - 1` - - Returns: - torch.Tensor: output of current input xs, - with shape (b=1, chunk_size, hidden-dim). - torch.Tensor: new attention cache required for next chunk, with - dynamic shape (elayers, head, ?, d_k * 2) - depending on required_cache_size. - torch.Tensor: new conformer cnn cache required for next chunk, with - same shape as the original cnn_cache. 
- - """ - assert xs.size(0) == 1 - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) - # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) - elayers, cache_t1 = att_cache.size(0), att_cache.size(2) - chunk_size = xs.size(1) - attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding(offset=offset - cache_t1, - size=attention_key_size) - if required_cache_size < 0: - next_cache_start = 0 - elif required_cache_size == 0: - next_cache_start = attention_key_size - else: - next_cache_start = max(attention_key_size - required_cache_size, 0) - r_att_cache = [] - r_cnn_cache = [] - for i, layer in enumerate(self.encoders): - # NOTE(xcsong): Before layer.forward - # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), - # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) - xs, _, new_att_cache, new_cnn_cache = layer( - xs, - att_mask, - pos_emb, - att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) - # NOTE(xcsong): After layer.forward - # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), - # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) - r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) - r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) - - xs = self.after_norm(xs) - - # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), - # ? may be larger than cache_t1, it depends on required_cache_size - r_att_cache = torch.cat(r_att_cache, dim=0) - # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) - r_cnn_cache = torch.cat(r_cnn_cache, dim=0) - - return (xs, r_att_cache, r_cnn_cache) - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - decoding_chunk_size: int, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Forward input chunk by chunk with chunk_size like a streaming - fashion - - Here we should pay special attention to computation cache in the - streaming style forward chunk by chunk. Three things should be taken - into account for computation in the current network: - 1. transformer/conformer encoder layers output cache - 2. convolution in conformer - 3. convolution in subsampling - - However, we don't implement subsampling cache for: - 1. We can control subsampling module to output the right result by - overlapping input instead of cache left context, even though it - wastes some computation, but subsampling only takes a very - small fraction of computation in the whole model. - 2. Typically, there are several covolution layers with subsampling - in subsampling module, it is tricky and complicated to do cache - with different convolution layers with different subsampling - rate. - 3. Currently, nn.Sequential is used to stack all the convolution - layers in subsampling, we need to rewrite it to make it work - with cache, which is not prefered. 
- Args: - xs (torch.Tensor): (1, max_len, dim) - chunk_size (int): decoding chunk size - """ - assert decoding_chunk_size > 0 - # The model is trained by static or dynamic chunk - assert self.static_chunk_size > 0 or self.use_dynamic_chunk - subsampling = self.embed.subsampling_rate - context = self.embed.right_context + 1 # Add current frame - stride = subsampling * decoding_chunk_size - decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.size(1) - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - outputs = [] - offset = 0 - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - - # Feed forward overlap input step by step - for cur in range(0, num_frames - context + 1, stride): - end = min(cur + decoding_window, num_frames) - chunk_xs = xs[:, cur:end, :] - (y, att_cache, - cnn_cache) = self.forward_chunk(chunk_xs, offset, - required_cache_size, att_cache, - cnn_cache) - outputs.append(y) - offset += y.size(1) - ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), - device=ys.device, - dtype=torch.bool) - return ys, masks + self.encoders = LayerDropModuleList( + p=stochastic_depth_rate, + modules=[ + EBranchformerEncoderLayer( + output_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + cgmlp_layer(*cgmlp_layer_args), + mlp_class(*positionwise_layer_args) if use_ffn else None, + mlp_class(*positionwise_layer_args) + if use_ffn and macaron_style else None, + dropout_rate, + merge_conv_kernel=merge_conv_kernel, + causal=causal, + stochastic_depth_rate=stochastic_depth_rate[lnum], + ) for lnum in range(num_blocks) + ]) diff --git a/wenet/e_branchformer/encoder_layer.py b/wenet/e_branchformer/encoder_layer.py index dba232383a..cad2ac4dab 100644 --- a/wenet/e_branchformer/encoder_layer.py +++ b/wenet/e_branchformer/encoder_layer.py @@ -20,6 +20,8 @@ import torch.nn as nn from typing import Optional, Tuple +from wenet.transformer.attention import T_CACHE + class EBranchformerEncoderLayer(torch.nn.Module): """E-Branchformer encoder layer module. @@ -88,47 +90,17 @@ def __init__( self.merge_proj = torch.nn.Linear(size + size, size) self.stochastic_depth_rate = stochastic_depth_rate - def forward( + def _forward( self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute encoded features. - - Args: - x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time, time). - pos_emb (torch.Tensor): positional encoding, must not be None - for BranchformerEncoderLayer. - mask_pad (torch.Tensor): batch padding mask used for conv module. - (#batch, 1,time), (0, 0, 0) means fake mask. - att_cache (torch.Tensor): Cache tensor of the KEY & VALUE - (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. - cnn_cache (torch.Tensor): Convolution cache in cgmlp layer - (#batch=1, size, cache_t2) - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time, time. - torch.Tensor: att_cache tensor, - (#batch=1, head, cache_t1 + time, d_k * 2). 
- torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). - """ - - stoch_layer_coeff = 1.0 - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - return x, mask, att_cache, cnn_cache + stoch_layer_coeff: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: if self.feed_forward_macaron is not None: residual = x @@ -173,3 +145,43 @@ def forward( x = self.norm_final(x) return x, mask, new_att_cache, new_cnn_cache + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: + """Compute encoded features. + + Args: + x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time, time). + pos_emb (torch.Tensor): positional encoding, must not be None + for BranchformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in cgmlp layer + (#batch=1, size, cache_t2) + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time. + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + stoch_layer_coeff = 1.0 + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + if self.training: + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache, + stoch_layer_coeff) diff --git a/wenet/finetune/lora/attention.py b/wenet/finetune/lora/attention.py new file mode 100644 index 0000000000..4cb880f9e9 --- /dev/null +++ b/wenet/finetune/lora/attention.py @@ -0,0 +1,123 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alan (alanfangemail@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
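A note on the att_cache signature change above in wenet/branchformer/encoder_layer.py and wenet/e_branchformer/encoder_layer.py: the layers now take a T_CACHE pair rather than a single concatenated tensor; judging from the new defaults, this is presumably the key and value caches held separately (the exact per-layer shapes depend on the attention implementation, so the sketch below only shows the shape convention of the empty default).

```python
import torch
from typing import Tuple

# Matches the alias imported from wenet.transformer.attention.
T_CACHE = Tuple[torch.Tensor, torch.Tensor]

# The empty default used in the new forward()/_forward() signatures; a real
# cache returned after a decoded chunk would hold this layer's cached
# attention states instead of zero-sized tensors.
empty_cache: T_CACHE = (torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0))
print(empty_cache[0].shape, empty_cache[1].shape)
```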
+"""Multi-Head Attention layer definition with lora.""" + +from typing import Optional, List + +import torch +from torch import nn + +from wenet.transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +import wenet.finetune.lora.layers as lora + + +class LoRAMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with lora. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + use_sdpa: bool = False, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None): + """Construct an MultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, + value_bias, use_sdpa) + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_out = lora.Linear( + n_feat, + n_feat, + r=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout + ) if lora_list and "o" in lora_list else nn.Linear(n_feat, n_feat) + + lora_qkv_dict = { + "q": lora_list and "q" in lora_list, + "k": lora_list and "k" in lora_list, + "v": lora_list and "v" in lora_list + } + bias_dict = {"q": query_bias, "k": key_bias, "v": value_bias} + + for key, value in lora_qkv_dict.items(): + setattr( + self, f"linear_{key}", + lora.Linear(n_feat, + n_feat, + r=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + bias=bias_dict[key]) if value else nn.Linear( + n_feat, n_feat, bias_dict[key])) + self.dropout = nn.Dropout(p=dropout_rate) + + +class LoRARelPositionMultiHeadedAttention(LoRAMultiHeadedAttention, + RelPositionMultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. 
+ """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + use_sdpa: bool = False, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, + value_bias, use_sdpa, lora_rank, lora_alpha, + lora_dropout, lora_list) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) diff --git a/wenet/finetune/lora/encoder.py b/wenet/finetune/lora/encoder.py new file mode 100644 index 0000000000..2de1e8594b --- /dev/null +++ b/wenet/finetune/lora/encoder.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alan (alanfangemail@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition with lora.""" + +from typing import Optional, List + +import torch + +from wenet.transformer.convolution import ConvolutionModule +from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.transformer.encoder_layer import ConformerEncoderLayer +from wenet.utils.class_utils import ( + WENET_MLP_CLASSES, + WENET_ACTIVATION_CLASSES, +) +from wenet.finetune.lora.utils import WENET_LORA_ATTENTION_CLASSES + + +class LoRATransformerEncoder(TransformerEncoder): + """Transformer encoder module with lora.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + mlp_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + mlp_type: str = 'position_wise_feed_forward', + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, query_bias, key_bias, + value_bias, mlp_bias, activation_type, + gradient_checkpointing, use_sdpa, mlp_type, + layer_norm_type, norm_eps, n_kv_head, head_dim) + activation = WENET_ACTIVATION_CLASSES[activation_type]() + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + WENET_LORA_ATTENTION_CLASSES["selfattn"]( + attention_heads, output_size, attention_dropout_rate, + query_bias, key_bias, value_bias, use_sdpa, n_kv_head, + head_dim, lora_rank, lora_alpha, lora_dropout, lora_list), + mlp_class(output_size, linear_units, dropout_rate, activation, + mlp_bias), + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + ) for _ in range(num_blocks) + ]) + + +class LoRAConformerEncoder(ConformerEncoder): + """Conformer encoder module with lora.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + mlp_bias: bool = True, + conv_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + mlp_type: str = 'position_wise_feed_forward', + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
+ """ + super().__init__( + input_size, output_size, attention_heads, linear_units, num_blocks, + dropout_rate, positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, positionwise_conv_kernel_size, + macaron_style, selfattention_layer_type, activation_type, + use_cnn_module, cnn_module_kernel, causal, cnn_module_norm, + query_bias, key_bias, value_bias, mlp_bias, conv_bias, + gradient_checkpointing, use_sdpa, mlp_type, layer_norm_type, + norm_eps, n_kv_head, head_dim) + activation = WENET_ACTIVATION_CLASSES[activation_type]() + + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, + lora_rank, + lora_alpha, + lora_dropout, + lora_list, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + mlp_bias, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal, conv_bias) + + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + WENET_LORA_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + mlp_class(*positionwise_layer_args), + mlp_class(*positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + ) for _ in range(num_blocks) + ]) diff --git a/wenet/finetune/lora/layers.py b/wenet/finetune/lora/layers.py new file mode 100644 index 0000000000..3982ef279f --- /dev/null +++ b/wenet/finetune/lora/layers.py @@ -0,0 +1,350 @@ +# Copyright (c) 2021 microsoft +# 2023 Alan (alanfangemail@gmail.com) +# ----------------------------------------------------------------------------- +# Licensed under the MIT License (MIT). See LICENSE in the repo root for +# license information. 
+# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +from typing import List + + +class LoRALayer(): + + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = self.identity + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + def identity(self, x): + return x + + +class Embedding(nn.Embedding, LoRALayer): + # LoRA implemented in a dense layer + def __init__(self, + num_embeddings: int, + embedding_dim: int, + r: int = 0, + lora_alpha: int = 1, + merge_weights: bool = True, + **kwargs): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + LoRALayer.__init__(self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=0, + merge_weights=merge_weights) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.weight.new_zeros((r, num_embeddings))) + self.lora_B = nn.Parameter( + self.weight.new_zeros((embedding_dim, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + + def reset_parameters(self): + nn.Embedding.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.zeros_(self.lora_A) + nn.init.normal_(self.lora_B) + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + temp = (self.lora_B @ self.lora_A).transpose(0, 1) + self.weight.data -= temp * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + temp = (self.lora_B @ self.lora_A).transpose(0, 1) + self.weight.data += temp * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.r > 0 and not self.merged: + result = nn.Embedding.forward(self, x) + after_A = F.embedding(x, self.lora_A.transpose(0, 1), + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return nn.Embedding.forward(self, x) + + +class Linear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, + # Set this to True if the layer to replace stores weight like (fan_in, + # fan_out) + merge_weights: bool = True, + **kwargs): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros( + (out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def 
reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def T(self, w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + temp = self.T(self.lora_B @ self.lora_A) + self.weight.data -= temp * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + temp = self.T(self.lora_B @ self.lora_A) + self.weight.data += temp * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.r > 0 and not self.merged: + result = F.linear(x, self.T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) + @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, self.T(self.weight), bias=self.bias) + + +class MergedLinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__(self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + enable_lora: List[bool] = None, + fan_in_fan_out: bool = False, + merge_weights: bool = True, + **kwargs): + if enable_lora is None: + enable_lora = [False] + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + assert out_features % len(enable_lora) == 0, \ + 'The length of enable_lora must divide out_features' + self.enable_lora = enable_lora + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * sum(enable_lora), in_features))) + self.lora_B = nn.Parameter( + self.weight.new_zeros( + (out_features // len(enable_lora) * sum(enable_lora), r))) + # weights for Conv1D with groups=sum(enable_lora) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros( + (out_features, ), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def zero_pad(self, x): + result = x.new_zeros((len(self.lora_ind), *x.size()[1:])) + result[self.lora_ind] = x + return result + + def T(self, w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + def merge_AB(self): + delta_w = F.conv1d(self.lora_A.unsqueeze(0), + self.lora_B.unsqueeze(-1), + groups=sum(self.enable_lora)).squeeze(0) + return self.T(delta_w) + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0 and any(self.enable_lora): + self.weight.data -= self.merge_AB() * self.scaling + self.merged 
= False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0 and any(self.enable_lora): + self.weight.data += self.merge_AB() * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.merged: + return F.linear(x, self.T(self.weight), bias=self.bias) + else: + result = F.linear(x, self.T(self.weight), bias=self.bias) + if self.r > 0: + temp = self.T(self.merge_AB().T) + result += self.lora_dropout(x) @ temp * self.scaling + return result + + +class ConvLoRA(nn.Module, LoRALayer): + + def __init__(self, + conv_module, + in_channels, + out_channels, + kernel_size, + r=0, + lora_alpha=1, + lora_dropout=0., + merge_weights=True, + **kwargs): + super(ConvLoRA, self).__init__() + self.conv = conv_module(in_channels, out_channels, kernel_size, + **kwargs) + LoRALayer.__init__(self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + assert isinstance(kernel_size, int) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.conv.weight.new_zeros( + (r * kernel_size, in_channels * kernel_size))) + self.lora_B = nn.Parameter( + self.conv.weight.new_zeros( + (out_channels // self.conv.groups * kernel_size, + r * kernel_size))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.conv.weight.requires_grad = False + self.reset_parameters() + self.merged = False + + def reset_parameters(self): + self.conv.reset_parameters() + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def train(self, mode=True): + super(ConvLoRA, self).train(mode) + if mode: + if self.merge_weights and self.merged: + if self.r > 0: + # Make sure that the weights are not merged + self.conv.weight.data -= (self.lora_B @ self.lora_A).view( + self.conv.weight.shape) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + if self.r > 0: + # Merge the weights and mark it + self.conv.weight.data += (self.lora_B @ self.lora_A).view( + self.conv.weight.shape) * self.scaling + self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + return self.conv._conv_forward( + x, self.conv.weight + + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * + self.scaling, self.conv.bias) + return self.conv(x) + + +class Conv2d(ConvLoRA): + + def __init__(self, *args, **kwargs): + super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs) + + +class Conv1d(ConvLoRA): + + def __init__(self, *args, **kwargs): + super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs) + + +# Can Extend to other ones like this +class Conv3d(ConvLoRA): + + def __init__(self, *args, **kwargs): + super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs) diff --git a/wenet/finetune/lora/utils.py b/wenet/finetune/lora/utils.py new file mode 100644 index 0000000000..d6aebe9391 --- /dev/null +++ b/wenet/finetune/lora/utils.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 microsoft +# 2023 Alan (alanfangemail@gmail.com) +# ----------------------------------------------------------------------------- +# Licensed under the MIT License (MIT). See LICENSE in the repo root for +# license information. 
+# ----------------------------------------------------------------------------- + +import logging +import torch +import torch.nn as nn + +from typing import Dict + +from wenet.finetune.lora.attention import (LoRARelPositionMultiHeadedAttention, + LoRAMultiHeadedAttention) +from wenet.finetune.lora.layers import LoRALayer + +WENET_LORA_ATTENTION_CLASSES = { + "selfattn": LoRAMultiHeadedAttention, + "rel_selfattn": LoRARelPositionMultiHeadedAttention, +} + + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: + logging.info('freezing all params except lora module.') + for n, p in model.named_parameters(): + if 'lora_' not in n: + p.requires_grad = False + if bias == 'none': + return + elif bias == 'all': + for n, p in model.named_parameters(): + if 'bias' in n: + p.requires_grad = True + elif bias == 'lora_only': + for m in model.modules(): + if isinstance(m, LoRALayer) and \ + hasattr(m, 'bias') and \ + m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_state_dict(model: nn.Module, + bias: str = 'none') -> Dict[str, torch.Tensor]: + my_state_dict = model.state_dict() + if bias == 'none': + return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} + elif bias == 'all': + return { + k: my_state_dict[k] + for k in my_state_dict if 'lora_' in k or 'bias' in k + } + elif bias == 'lora_only': + to_return = {} + for k in my_state_dict: + if 'lora_' in k: + to_return[k] = my_state_dict[k] + bias_name = k.split('lora_')[0] + 'bias' + if bias_name in my_state_dict: + to_return[bias_name] = my_state_dict[bias_name] + return to_return + else: + raise NotImplementedError diff --git a/wenet/k2/model.py b/wenet/k2/model.py index bbc580cdc3..d76d89dd78 100644 --- a/wenet/k2/model.py +++ b/wenet/k2/model.py @@ -54,7 +54,7 @@ def __init__( if self.lfmmi_dir != '': self.load_lfmmi_resource() - @torch.jit.ignore(drop=True) + @torch.jit.unused def _forward_ctc( self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, text: torch.Tensor, @@ -63,7 +63,7 @@ def _forward_ctc( text) return loss_ctc, ctc_probs - @torch.jit.ignore(drop=True) + @torch.jit.unused def load_lfmmi_resource(self): try: import icefall @@ -94,7 +94,7 @@ def load_lfmmi_resource(self): assert len(arr) == 2 self.word_table[int(arr[1])] = arr[0] - @torch.jit.ignore(drop=True) + @torch.jit.unused def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text): try: import k2 diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py index ff5b849dc9..d17280d8a5 100644 --- a/wenet/paraformer/layers.py +++ b/wenet/paraformer/layers.py @@ -282,7 +282,7 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward_layers_checkpointed(self, xs: torch.Tensor, chunk_masks: torch.Tensor, pos_emb: torch.Tensor, @@ -290,8 +290,12 @@ def forward_layers_checkpointed(self, xs: torch.Tensor, for layer in self.encoders0: xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) for layer in self.encoders: - xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, chunk_masks, - pos_emb, mask_pad) + xs, _, _, _ = ckpt.checkpoint(layer.__call__, + xs, + chunk_masks, + pos_emb, + mask_pad, + use_reentrant=False) return xs @@ -471,7 +475,7 @@ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, x = layer(x) return x - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward_layers_checkpointed(self, x: torch.Tensor, 
tgt_mask: torch.Tensor, memory: torch.Tensor, @@ -480,8 +484,12 @@ def forward_layers_checkpointed(self, x: torch.Tensor, if i == 0: x, _, _, _ = layer(x, tgt_mask, memory, memory_mask) else: - x, _, _, _ = ckpt.checkpoint(layer.__call__, x, tgt_mask, - memory, memory_mask) + x, _, _, _ = ckpt.checkpoint(layer.__call__, + x, + tgt_mask, + memory, + memory_mask, + use_reentrant=False) for layer in self.decoders3: x = layer(x) return x diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py index 64b3587ec4..be19f15b49 100644 --- a/wenet/paraformer/paraformer.py +++ b/wenet/paraformer/paraformer.py @@ -148,7 +148,7 @@ def __init__(self, # labels: 你 好 we@@ net eos self.add_eos = add_eos - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward( self, batch: Dict, @@ -232,7 +232,7 @@ def _calc_att_loss( ignore_label=self.ignore_id) return loss_att, acc_att - @torch.jit.ignore(drop=True) + @torch.jit.unused def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens, pre_acoustic_embeds): device = encoder_out.device diff --git a/wenet/ssl/w2vbert/w2vbert_model.py b/wenet/ssl/w2vbert/w2vbert_model.py index b874595297..27db0abf1e 100644 --- a/wenet/ssl/w2vbert/w2vbert_model.py +++ b/wenet/ssl/w2vbert/w2vbert_model.py @@ -158,7 +158,7 @@ def _reset_parameter(module: torch.nn.Module): _reset_parameter(conv1) _reset_parameter(conv2) - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward( self, batch: Dict, diff --git a/wenet/ssl/wav2vec2/wav2vec2_model.py b/wenet/ssl/wav2vec2/wav2vec2_model.py index 69d5af0222..9cbd0c3b32 100644 --- a/wenet/ssl/wav2vec2/wav2vec2_model.py +++ b/wenet/ssl/wav2vec2/wav2vec2_model.py @@ -217,7 +217,7 @@ def _reset_parameter(module: torch.nn.Module): _reset_parameter(conv1) _reset_parameter(conv2) - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward( self, batch: Dict, diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index f352271abf..fb5e6c6920 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -74,7 +74,7 @@ def __init__( normalize_length=length_normalized_loss, ) - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward( self, batch: dict, @@ -133,7 +133,7 @@ def forward( "th_accuracy": acc_att, } - @torch.jit.ignore(drop=True) + @torch.jit.unused def _forward_ctc( self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, text: torch.Tensor, @@ -231,7 +231,7 @@ def _forward_encoder( ) # (B, maxlen, encoder_dim) return encoder_out, encoder_mask - @torch.jit.ignore(drop=True) + @torch.jit.unused def ctc_logprobs(self, encoder_out: torch.Tensor, blank_penalty: float = 0.0, @@ -245,6 +245,9 @@ def ctc_logprobs(self, return ctc_probs + def tie_or_clone_weights(self, jit_mode: bool = True): + self.decoder.tie_or_clone_weights(jit_mode) + def decode( self, methods: List[str], diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index 54b76daade..230481acaf 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -21,7 +21,9 @@ import torch from torch import nn -from wenet.utils.rope_utils import llama_apply_rotary_emb +from wenet.utils.rope_utils import WENET_APPLY_ROTARY_EMB + +T_CACHE = Tuple[torch.Tensor, torch.Tensor] class MultiHeadedAttention(nn.Module): @@ -64,7 +66,7 @@ def __init__(self, self.inner_kv_dim = self.inner_dim n_kv_head = n_head # We assume d_v always equals d_k - self.d_k = n_feat // n_head + self.d_k = self.inner_dim // n_head assert self.d_k == self.inner_kv_dim // n_kv_head 
self.h = n_head self.h_kv = n_kv_head @@ -78,7 +80,10 @@ def __init__(self, self.use_sdpa = use_sdpa self.dropout_rate = dropout_rate - def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor: + def _forward_linearx(self, + name: str, + x: torch.Tensor, + head_first: bool = True) -> torch.Tensor: assert x.ndim >= 3 if name == 'query': x = self.linear_q(x) @@ -96,7 +101,9 @@ def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor: # split last dim x = x.view(x_shape) - x = x.transpose(-3, -2) # (batch, ..., head or head_kv, time, d_k) + if head_first: + x = x.transpose(-3, + -2) # (batch, ..., head or head_kv, time, d_k) return x def forward_qkv( @@ -169,6 +176,56 @@ def forward_attention( x = x.view(x_shape) # (batch, ..., time1, d_model) return self.linear_out(x) # (batch, ..., time1, d_model) + def _update_kv_and_cache( + self, + k: torch.Tensor, + v: torch.Tensor, + cache: T_CACHE, + head_first: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]: + new_cache = cache + seq_axis = -2 if head_first else -3 + head_axis = -3 if head_first else -2 + if not self.training: + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + key_cache, value_cache = cache + if key_cache.size(0) > 0: + k = torch.cat([key_cache, k], dim=seq_axis) + if value_cache.size(0) > 0: + v = torch.cat([value_cache, v], dim=seq_axis) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + # new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + new_cache = (k, v) + # for multi query or multi group attention + if self.h_kv != self.h and self.h_kv != 1: + k = torch.repeat_interleave( + k, + self.h // self.h_kv, + dim=head_axis, + ) + v = torch.repeat_interleave( + v, + self.h // self.h_kv, + dim=-head_axis, + ) + return k, v, new_cache + def forward( self, query: torch.Tensor, @@ -176,8 +233,8 @@ def forward( value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache: T_CACHE = (torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, T_CACHE]: """Compute scaled dot product attention. Args: @@ -209,45 +266,7 @@ def forward( """ q, k, v = self.forward_qkv(query, key, value) - - # NOTE(xcsong): - # when export onnx model, for 1st chunk, we feed - # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) - # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). - # In all modes, `if cache.size(0) > 0` will alwayse be `True` - # and we will always do splitting and - # concatnation(this will simplify onnx export). Note that - # it's OK to concat & split zero-shaped tensors(see code below). 
- # when export jit model, for 1st chunk, we always feed - # cache(0, 0, 0, 0) since jit supports dynamic if-branch. - # >>> a = torch.ones((1, 2, 0, 4)) - # >>> b = torch.ones((1, 2, 3, 4)) - # >>> c = torch.cat((a, b), dim=2) - # >>> torch.equal(b, c) # True - # >>> d = torch.split(a, 2, dim=-1) - # >>> torch.equal(d[0], d[1]) # True - if cache.size(0) > 0: - key_cache, value_cache = torch.split(cache, - cache.size(-1) // 2, - dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) - - # for multi query or multi group attention - if self.h_kv != self.h: - k = torch.repeat_interleave( - k, - self.h // self.h_kv, - dim=-3, - ) - v = torch.repeat_interleave( - v, - self.h // self.h_kv, - dim=-3, - ) + k, v, new_cache = self._update_kv_and_cache(k, v, cache) if not self.use_sdpa: scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) @@ -331,8 +350,8 @@ def forward( value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros((0, 0, 0, 0))) + ) -> Tuple[torch.Tensor, T_CACHE]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (torch.Tensor): Query tensor (#batch, time1, size). @@ -353,46 +372,7 @@ def forward( """ q, k, v = self.forward_qkv(query, key, value) q = q.transpose(1, 2) # (batch, time1, head, d_k) - - # NOTE(xcsong): - # when export onnx model, for 1st chunk, we feed - # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) - # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). - # In all modes, `if cache.size(0) > 0` will alwayse be `True` - # and we will always do splitting and - # concatnation(this will simplify onnx export). Note that - # it's OK to concat & split zero-shaped tensors(see code below). - # when export jit model, for 1st chunk, we always feed - # cache(0, 0, 0, 0) since jit supports dynamic if-branch. - # >>> a = torch.ones((1, 2, 0, 4)) - # >>> b = torch.ones((1, 2, 3, 4)) - # >>> c = torch.cat((a, b), dim=2) - # >>> torch.equal(b, c) # True - # >>> d = torch.split(a, 2, dim=-1) - # >>> torch.equal(d[0], d[1]) # True - if cache.size(0) > 0: - key_cache, value_cache = torch.split(cache, - cache.size(-1) // 2, - dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - - # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. 
- new_cache = torch.cat((k, v), dim=-1) - - # for multi query or multi groups attention - if self.h_kv != self.h: - k = torch.repeat_interleave( - k, - self.h // self.h_kv, - dim=-3, - ) - v = torch.repeat_interleave( - v, - self.h // self.h_kv, - dim=-3, - ) + k, v, new_cache = self._update_kv_and_cache(k, v, cache) n_batch_pos = pos_emb.size(0) p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) @@ -462,20 +442,21 @@ def forward( value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros((0, 0, 0, 0))) + ) -> Tuple[torch.Tensor, T_CACHE]: del pos_emb - if cache.size(0) > 0: + key_cache, value_cache = cache + assert key_cache.size(0) == value_cache.size(0) + if key_cache.size(0) > 0: assert not self.training q = self._forward_linearx('query', query) - k, v = torch.split(cache, cache.size(-1) // 2, dim=-1) + k, v = key_cache, value_cache else: q, k, v = self.forward_qkv(query, key, value) - new_cache = torch.cat((k, v), dim=-1) - + new_cache = (k, v) if not self.training else cache # for multi query or multi groups attention - if self.h_kv != self.h: + if self.h_kv != self.h and self.h_kv != 1: k = torch.repeat_interleave( k, self.h // self.h_kv, @@ -486,7 +467,6 @@ def forward( self.h // self.h_kv, dim=-3, ) - B = query.size(0) Beams = 1 if B != k.size(0): @@ -559,17 +539,11 @@ def forward( value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, T_CACHE]: del pos_emb q, k, v = self.forward_qkv(query, key, value) - if cache.size(0) > 0: - key_cache, value_cache = torch.split(cache, - cache.size(-1) // 2, - dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - new_cache = torch.cat((k, v), dim=-1) + k, v, new_cache = self._update_kv_and_cache(k, v, cache) rel_k = self.rel_k_embed( self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k) @@ -615,9 +589,11 @@ def __init__(self, value_bias: bool = True, use_sdpa: bool = False, n_kv_head: Optional[int] = None, - head_dim: Optional[int] = None): + head_dim: Optional[int] = None, + style='google'): super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, value_bias, use_sdpa, n_kv_head, head_dim) + self.style = style def forward( self, @@ -626,8 +602,8 @@ def forward( value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, T_CACHE]: """Compute rope scaled dot product attention. Args: @@ -658,11 +634,13 @@ def forward( and `head * d_k == size` """ - q, k, v = self.forward_qkv(query, key, value) + q = self._forward_linearx('query', query, head_first=False) + k = self._forward_linearx('key', key, head_first=False) + v = self._forward_linearx('value', value, head_first=False) # NOTE(Mddct): In order to make the code easier to read, # these two lines are not placed in MultiHeadedAttention. 
- q = llama_apply_rotary_emb(q, pos_emb) - k = llama_apply_rotary_emb(k, pos_emb) + q = WENET_APPLY_ROTARY_EMB[self.style](q, pos_emb) + k = WENET_APPLY_ROTARY_EMB[self.style](k, pos_emb) # see above if cache.size(0) > 0: key_cache, value_cache = torch.split(cache, @@ -670,20 +648,16 @@ def forward( dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) - new_cache = torch.cat((k, v), dim=-1) - - if self.h_kv != self.h: - k = torch.repeat_interleave( - k, - self.h // self.h_kv, - dim=1, - ) - v = torch.repeat_interleave( - v, - self.h // self.h_kv, - dim=1, - ) - + new_cache = torch.cat( + (k, v), dim=-1) if not self.training else torch.empty(0, 0, 0, 0) + + k, v, new_cache = self._update_kv_and_cache(k, + v, + cache, + head_first=False) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if not self.use_sdpa: scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask), new_cache diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index 4b165c6aa0..0c4fab62af 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -15,9 +15,11 @@ """Decoder definition.""" from typing import Dict, Tuple, List, Optional +import os import torch import torch.utils.checkpoint as ckpt import logging +from wenet.transformer.attention import T_CACHE from wenet.transformer.decoder_layer import DecoderLayer from wenet.utils.class_utils import ( @@ -76,16 +78,18 @@ def __init__( query_bias: bool = True, key_bias: bool = True, value_bias: bool = True, - mlp_bias: bool = True, activation_type: str = "relu", gradient_checkpointing: bool = False, tie_word_embedding: bool = False, use_sdpa: bool = False, - mlp_type: str = 'position_wise_feed_forward', layer_norm_type: str = 'layer_norm', norm_eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): super().__init__() attention_dim = encoder_output_size @@ -121,8 +125,13 @@ def __init__( attention_heads, attention_dim, src_attention_dropout_rate, query_bias, key_bias, value_bias, use_sdpa, n_kv_head, head_dim) if src_attention else None, - mlp_class(attention_dim, linear_units, dropout_rate, - activation, mlp_bias), + mlp_class(attention_dim, + linear_units, + dropout_rate, + activation, + mlp_bias, + n_expert=n_expert, + n_expert_activated=n_expert_activated), dropout_rate, normalize_before, layer_norm_type, @@ -199,14 +208,19 @@ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, memory_mask) return x - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward_layers_checkpointed(self, x: torch.Tensor, tgt_mask: torch.Tensor, memory: torch.Tensor, memory_mask: torch.Tensor) -> torch.Tensor: for layer in self.decoders: x, tgt_mask, memory, memory_mask = ckpt.checkpoint( - layer.__call__, x, tgt_mask, memory, memory_mask) + layer.__call__, + x, + tgt_mask, + memory, + memory_mask, + use_reentrant=False) return x def forward_one_step( @@ -215,7 +229,7 @@ def forward_one_step( memory_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor, - cache: Dict[str, Dict[str, torch.Tensor]], + cache: Dict[str, Dict[str, T_CACHE]], ) -> torch.Tensor: """Forward one step. This is only used for decoding. 
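Editor's note on the cache refactor above: the attention modules now pass the KV cache around as a `(key, value)` tuple (`T_CACHE`) instead of a single tensor packed with `torch.cat((k, v), dim=-1)`, and `_update_kv_and_cache` both grows the cache along the time axis and expands grouped KV heads to match the query heads. The sketch below is a stand-alone toy of that update, not the wenet class; the function name and shapes are illustrative only. Note that in the patch the value branch uses `dim=-head_axis`, which for the head-first layout appears to point at the feature axis rather than the head axis, so the sketch expands both `k` and `v` along the head axis.

```python
import torch
from typing import Tuple

T_CACHE = Tuple[torch.Tensor, torch.Tensor]  # (key_cache, value_cache)


def update_kv_and_cache(k: torch.Tensor, v: torch.Tensor, cache: T_CACHE,
                        n_head: int,
                        n_kv_head: int) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]:
    """Toy version of the paired-cache update: concatenate along the time
    axis, return the grown cache as a (k, v) tuple, then expand grouped KV
    heads so they match the number of query heads."""
    key_cache, value_cache = cache
    if key_cache.size(0) > 0:  # an empty (0, 0, 0, 0) cache is skipped
        k = torch.cat([key_cache, k], dim=2)
        v = torch.cat([value_cache, v], dim=2)
    new_cache = (k, v)  # replaces the old torch.cat((k, v), dim=-1) packing
    if n_kv_head != n_head and n_kv_head != 1:
        k = torch.repeat_interleave(k, n_head // n_kv_head, dim=1)
        v = torch.repeat_interleave(v, n_head // n_kv_head, dim=1)
    return k, v, new_cache


# One decoding step: 2 KV heads serving 4 query heads, a chunk of 3 frames.
empty: T_CACHE = (torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0))
k = torch.randn(1, 2, 3, 64)
v = torch.randn(1, 2, 3, 64)
k_full, v_full, cache = update_kv_and_cache(k, v, empty, n_head=4, n_kv_head=2)
print(k_full.shape, cache[0].shape)  # torch.Size([1, 4, 3, 64]) torch.Size([1, 2, 3, 64])
```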
@@ -269,14 +283,19 @@ def forward_one_step( def tie_or_clone_weights(self, jit_mode: bool = True): """Tie or clone module weights (between word_emb and output_layer) depending of whether we are using TorchScript or not""" + rank = int(os.environ.get('RANK', 0)) if not self.use_output_layer: return + if not self.tie_word_embedding: + return if jit_mode: - logging.info("clone emb.weight to output.weight") + if rank == 0: + logging.info("clone emb.weight to output.weight") self.output_layer.weight = torch.nn.Parameter( self.embed[0].weight.clone()) else: - logging.info("tie emb.weight with output.weight") + if rank == 0: + logging.info("tie emb.weight with output.weight") self.output_layer.weight = self.embed[0].weight if getattr(self.output_layer, "bias", None) is not None: @@ -327,10 +346,11 @@ def __init__( input_layer: str = "embed", use_output_layer: bool = True, normalize_before: bool = True, + src_attention: bool = True, query_bias: bool = True, key_bias: bool = True, value_bias: bool = True, - mlp_bias: bool = True, + activation_type: str = "relu", gradient_checkpointing: bool = False, tie_word_embedding: bool = False, use_sdpa: bool = False, @@ -338,6 +358,10 @@ def __init__( norm_eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): super().__init__() @@ -356,9 +380,11 @@ def __init__( input_layer, use_output_layer, normalize_before, + src_attention=src_attention, query_bias=query_bias, key_bias=key_bias, value_bias=value_bias, + activation_type=activation_type, gradient_checkpointing=gradient_checkpointing, tie_word_embedding=tie_word_embedding, use_sdpa=use_sdpa, @@ -366,7 +392,10 @@ def __init__( norm_eps=norm_eps, n_kv_head=n_kv_head, head_dim=head_dim, - ) + mlp_type=mlp_type, + mlp_bias=mlp_bias, + n_expert=n_expert, + n_expert_activated=n_expert_activated) self.right_decoder = TransformerDecoder( vocab_size, @@ -381,10 +410,11 @@ def __init__( input_layer, use_output_layer, normalize_before, + src_attention=src_attention, query_bias=query_bias, key_bias=key_bias, value_bias=value_bias, - mlp_bias=mlp_bias, + activation_type=activation_type, gradient_checkpointing=gradient_checkpointing, tie_word_embedding=tie_word_embedding, use_sdpa=use_sdpa, @@ -392,7 +422,10 @@ def __init__( norm_eps=norm_eps, n_kv_head=n_kv_head, head_dim=head_dim, - ) + mlp_type=mlp_type, + mlp_bias=mlp_bias, + n_expert=n_expert, + n_expert_activated=n_expert_activated) def forward( self, diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 3c06265355..f08d597b3c 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -17,6 +17,7 @@ import torch from torch import nn +from wenet.transformer.attention import T_CACHE from wenet.utils.class_utils import WENET_NORM_CLASSES @@ -70,7 +71,7 @@ def forward( tgt_mask: torch.Tensor, memory: torch.Tensor, memory_mask: torch.Tensor, - cache: Optional[Dict[str, Optional[torch.Tensor]]] = None + cache: Optional[Dict[str, Optional[T_CACHE]]] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute decoded features. 
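The `tie_or_clone_weights` change above adds an early return when `tie_word_embedding` is off and gates its log messages on rank 0. A minimal, self-contained sketch of the tie-vs-clone behaviour; the `TinyDecoder` class, its sizes, and the `bias=False` choice are illustrative stand-ins, not wenet's decoder:

```python
import logging
import os
import torch


class TinyDecoder(torch.nn.Module):
    """Hypothetical decoder with a word embedding and an output projection."""

    def __init__(self, vocab_size: int = 100, dim: int = 32,
                 tie_word_embedding: bool = True):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.output_layer = torch.nn.Linear(dim, vocab_size, bias=False)
        self.tie_word_embedding = tie_word_embedding

    def tie_or_clone_weights(self, jit_mode: bool = True):
        if not self.tie_word_embedding:
            return
        rank = int(os.environ.get('RANK', 0))
        if jit_mode:
            # TorchScript export cannot share one Parameter object, so clone it.
            if rank == 0:
                logging.info("clone emb.weight to output.weight")
            self.output_layer.weight = torch.nn.Parameter(
                self.embed.weight.clone())
        else:
            # Eager training: share the same Parameter object.
            if rank == 0:
                logging.info("tie emb.weight with output.weight")
            self.output_layer.weight = self.embed.weight


m = TinyDecoder()
m.tie_or_clone_weights(jit_mode=False)
assert m.output_layer.weight is m.embed.weight      # tied: one shared tensor
m.tie_or_clone_weights(jit_mode=True)
assert m.output_layer.weight is not m.embed.weight  # cloned: independent copy
```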
@@ -105,7 +106,7 @@ def forward( if att_cache is None: tgt_q = tgt tgt_q_mask = tgt_mask - att_cache = torch.empty(0, 0, 0, 0) + att_cache = (torch.empty(0, 0, 0, 0), torch.empty(0, 0, 0, 0)) else: tgt_q = tgt[:, -1:, :] residual = residual[:, -1:, :] @@ -129,7 +130,8 @@ def forward( if self.normalize_before: x = self.norm2(x) if cross_att_cache is None: - cross_att_cache = torch.empty(0, 0, 0, 0) + cross_att_cache = (torch.empty(0, 0, 0, + 0), torch.empty(0, 0, 0, 0)) x, new_cross_cache = self.src_attn(x, memory, memory, diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py index 2efa2f5fdc..162de63dc1 100644 --- a/wenet/transformer/embedding.py +++ b/wenet/transformer/embedding.py @@ -205,13 +205,15 @@ def __init__(self, head_dim: int, dropout_rate: float, max_len: int = 1500, - rope_theta=10000.0): + rope_theta=10000.0, + scale: bool = True): super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len) delattr(self, 'pe') - - pe = precompute_freqs_cis(head_dim, max_len * 2, rope_theta) - self.register_buffer("pe", pe.unsqueeze(0)) + self.max_len = max_len * 2 + pe = precompute_freqs_cis(head_dim, self.max_len, rope_theta) + self.register_buffer("pe", torch.view_as_real(pe.unsqueeze(0))) self.dropout_rate = dropout_rate + self.scale = scale def forward( self, @@ -219,13 +221,34 @@ def forward( offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]: - pos_emb = self.position_encoding(offset, x.size(1), False) - pos_emb = pos_emb.unsqueeze(1) # [1, 1, seq, head_dim//2] + pos_emb = self.position_encoding(offset, x.size(1), True) + pos_emb = pos_emb.unsqueeze(2) # [1, 1, seq, head_dim//2] # NOTE(Mddct): some model don't scale - # TODO(Mddct): fix - x = x * self.xscale - # NOTE(Mddct) dropout don't suuport complex float for pos_emb - return self.dropout(x), self.dropout_complex(pos_emb) + if self.scale: + x = x * self.xscale + return self.dropout(x), pos_emb + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + + pe = torch.view_as_complex(self.pe) + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = pe[:, offset:offset + size] + else: + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + torch.arange(0, size).to( + offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, pe[0]) # B X T X head_dim//2 + if apply_dropout: + # NOTE(Mddct) dropout don't suuport complex float for pos_emb + pos_emb = self.dropout_complex(pos_emb) + return pos_emb def dropout_complex(self, x): mask = torch.nn.functional.dropout( diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index 0f5ccef861..9cfd260ea9 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -184,15 +184,18 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs - @torch.jit.ignore(drop=True) + @torch.jit.unused def forward_layers_checkpointed(self, xs: torch.Tensor, chunk_masks: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor: for layer in self.encoders: - xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs, - chunk_masks, pos_emb, - mask_pad) + xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, + xs, + chunk_masks, + pos_emb, + mask_pad, + use_reentrant=False) return xs def forward_chunk( @@ -263,12 +266,20 @@ def 
forward_chunk( # NOTE(xcsong): Before layer.forward # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) - xs, _, new_att_cache, new_cnn_cache = layer( + if elayers == 0: + kv_cache = (att_cache, att_cache) + else: + i_kv_cache = att_cache[i:i + 1] + size = att_cache.size(-1) // 2 + kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, + size:]) + xs, _, new_kv_cache, new_cnn_cache = layer( xs, att_mask, pos_emb, - att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + att_cache=kv_cache, cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) + new_att_cache = torch.cat(new_kv_cache, dim=-1) # NOTE(xcsong): After layer.forward # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) @@ -371,16 +382,18 @@ def __init__( query_bias: bool = True, key_bias: bool = True, value_bias: bool = True, - mlp_bias: bool = True, activation_type: str = "relu", gradient_checkpointing: bool = False, use_sdpa: bool = False, - mlp_type: str = 'position_wise_feed_forward', layer_norm_type: str = 'layer_norm', norm_eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, selfattention_layer_type: str = "selfattn", + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): """ Construct TransformerEncoder @@ -404,8 +417,13 @@ def __init__( attention_heads, output_size, attention_dropout_rate, query_bias, key_bias, value_bias, use_sdpa, n_kv_head, head_dim), - mlp_class(output_size, linear_units, dropout_rate, activation, - mlp_bias), + mlp_class(output_size, + linear_units, + dropout_rate, + activation, + mlp_bias, + n_expert=n_expert, + n_expert_activated=n_expert_activated), dropout_rate, normalize_before, layer_norm_type=layer_norm_type, @@ -445,15 +463,17 @@ def __init__( query_bias: bool = True, key_bias: bool = True, value_bias: bool = True, - mlp_bias: bool = True, conv_bias: bool = True, gradient_checkpointing: bool = False, use_sdpa: bool = False, - mlp_type: str = 'position_wise_feed_forward', layer_norm_type: str = 'layer_norm', norm_eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): """Construct ConformerEncoder @@ -500,6 +520,8 @@ def __init__( dropout_rate, activation, mlp_bias, + n_expert, + n_expert_activated, ) # convolution module definition convolution_layer_args = (output_size, cnn_module_kernel, activation, diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index 641dbc53ab..068d3f99b6 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -15,10 +15,12 @@ # Modified from ESPnet(https://github.com/espnet/espnet) """Encoder self-attention layer definition.""" +from functools import partial from typing import Optional, Tuple import torch from torch import nn +from wenet.transformer.attention import T_CACHE from wenet.utils.class_utils import WENET_NORM_CLASSES @@ -48,14 +50,21 @@ def __init__( normalize_before: bool = True, layer_norm_type: str = 'layer_norm', norm_eps: float = 1e-5, + rms_norm_offset: bool = True, ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward assert layer_norm_type in ['layer_norm', 'rms_norm'] - self.norm1 = 
WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) - self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) + norm_class = WENET_NORM_CLASSES[layer_norm_type] + if layer_norm_type == "rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.norm1 = norm_class(size, eps=norm_eps) + self.norm2 = norm_class(size, eps=norm_eps) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before @@ -66,9 +75,10 @@ def forward( mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros((0, 0, 0, 0))), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: """Compute encoded features. Args: @@ -180,9 +190,10 @@ def forward( mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros((0, 0, 0, 0))), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: """Compute encoded features. Args: diff --git a/wenet/transformer/norm.py b/wenet/transformer/norm.py index 2c3756f13f..8039228630 100644 --- a/wenet/transformer/norm.py +++ b/wenet/transformer/norm.py @@ -9,14 +9,19 @@ def __init__( self, dim: int, eps: float = 1e-6, + add_unit_offset: bool = True, ): super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.ones(dim)) + self.add_unit_offset = add_unit_offset def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): x = self._norm(x.float()).type_as(x) - return x * self.weight + if self.add_unit_offset: + return x * (1 + self.weight) + else: + return x * self.weight diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py index 7d6ab3251e..e4c38e0f99 100644 --- a/wenet/transformer/positionwise_feed_forward.py +++ b/wenet/transformer/positionwise_feed_forward.py @@ -31,12 +31,14 @@ class PositionwiseFeedForward(torch.nn.Module): """ def __init__( - self, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU(), - bias: bool = True, + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + *dummy_args, + **dummy_kwargs, ): """Construct a PositionwiseFeedForward object.""" super(PositionwiseFeedForward, self).__init__() @@ -66,7 +68,7 @@ class MoEFFNLayer(torch.nn.Module): https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 Args: n_expert: number of expert. - n_expert_per_token: The actual number of experts used for each frame + n_expert_activated: The actual number of experts used for each frame idim (int): Input dimenstion. hidden_units (int): The number of hidden units. dropout_rate (float): Dropout rate. 
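For the `rms_norm` path above, the encoder layer now builds its norms through `functools.partial` so the new `add_unit_offset` flag reaches `RMSNorm`, where the learned scale is applied as `(1 + weight)` instead of `weight` (the convention used by Gemma-style checkpoints). A small stand-alone sketch of that behaviour, with illustrative shapes; the real module lives in `wenet/transformer/norm.py`:

```python
import torch


class RMSNorm(torch.nn.Module):
    """Sketch of the RMSNorm variant above: with add_unit_offset=True the
    learned weight is applied as (1 + weight), matching checkpoints that
    store a zero-centered scale."""

    def __init__(self, dim: int, eps: float = 1e-6,
                 add_unit_offset: bool = True):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.add_unit_offset = add_unit_offset

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide each frame by its root-mean-square over the feature axis.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._norm(x.float()).type_as(x)
        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return x * scale


x = torch.randn(2, 5, 8)
y = RMSNorm(8, add_unit_offset=False)(x)
# Every frame has roughly unit RMS before the learned scale is applied.
print(y.pow(2).mean(-1).sqrt())
```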
@@ -74,22 +76,23 @@ class MoEFFNLayer(torch.nn.Module): """ def __init__( - self, - n_expert: int, - n_expert_per_token: int, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU(), - bias: bool = False, + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = False, + n_expert: int = 8, + n_expert_activated: int = 2, ): super(MoEFFNLayer, self).__init__() - bias = False - self.gate = torch.nn.Linear(idim, n_expert, bias=bias) + self.gate = torch.nn.Linear(idim, n_expert, bias=False) self.experts = torch.nn.ModuleList( - PositionwiseFeedForward(idim, hidden_units, dropout_rate, - activation) for _ in range(n_expert)) - self.n_expert_per_token = n_expert_per_token + PositionwiseFeedForward( + idim, hidden_units, dropout_rate, activation, bias=bias) + for _ in range(n_expert)) + self.n_expert = n_expert + self.n_expert_activated = n_expert_activated def forward(self, xs: torch.Tensor) -> torch.Tensor: """Foward function. @@ -103,18 +106,18 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor: ) # batch size, sequence length, embedding dimension (idim) xs = xs.view(-1, D) # (B*L, D) router = self.gate(xs) # (B*L, n_expert) - logits, indices = torch.topk( - router, self.n_expert_per_token - ) # probs:(B*L, n_expert), indices: (B*L, n_expert) + logits, selected_experts = torch.topk( + router, self.n_expert_activated + ) # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated) weights = torch.nn.functional.softmax( logits, dim=1, - dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) + dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_activated) output = torch.zeros_like(xs) # (B*L, D) for i, expert in enumerate(self.experts): - mask = indices == i - batch_idx, ith_expert = torch.where(mask) - output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( - xs[batch_idx]) + mask = selected_experts == i + token_ids, ith_expert = torch.where(mask) + output[token_ids] += weights[token_ids, ith_expert, None] * expert( + xs[token_ids]) return output.view(B, L, D) @@ -123,12 +126,14 @@ class GatedVariantsMLP(torch.nn.Module): """ def __init__( - self, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.GELU(), - bias: bool = True, + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.GELU(), + bias: bool = True, + *dummy_args, + **dummy_kwargs, ): """Construct a PositionwiseFeedForward object.""" super(GatedVariantsMLP, self).__init__() @@ -140,7 +145,7 @@ def __init__( # w_2 as down proj self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias) - def forward(self, x): + def forward(self, x) -> torch.Tensor: """Foward function. 
Args: xs: input tensor (B, L, D) diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py index 862edb8637..9919a44dba 100644 --- a/wenet/transformer/search.py +++ b/wenet/transformer/search.py @@ -320,7 +320,8 @@ def attention_beam_search( -1, 1).repeat([1, beam_size]) * beam_size).view(-1) # (B*N) cache_index = base_cache_index + cache_index cache['self_att_cache'] = { - i_layer: torch.index_select(value, dim=0, index=cache_index) + i_layer: (torch.index_select(value[0], dim=0, index=cache_index), + torch.index_select(value[1], dim=0, index=cache_index)) for (i_layer, value) in cache['self_att_cache'].items() } # NOTE(Mddct): we don't need select cross att here diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py index 42fc8fe670..8a2dfba614 100644 --- a/wenet/utils/checkpoint.py +++ b/wenet/utils/checkpoint.py @@ -24,14 +24,17 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: - logging.info('Checkpoint: loading from checkpoint %s' % path) - checkpoint = torch.load(path, map_location='cpu') + rank = int(os.environ.get('RANK', 0)) + logging.info('[Rank {}] Checkpoint: loading from checkpoint {}'.format( + rank, path)) + checkpoint = torch.load(path, map_location='cpu', mmap=True) missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) - for key in missing_keys: - logging.info("missing tensor: {}".format(key)) - for key in unexpected_keys: - logging.info("unexpected tensor: {}".format(key)) + if rank == 0: + for key in missing_keys: + logging.info("missing tensor: {}".format(key)) + for key in unexpected_keys: + logging.info("unexpected tensor: {}".format(key)) info_path = re.sub('.pt$', '.yaml', path) configs = {} if os.path.exists(info_path): @@ -40,29 +43,36 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: return configs +def save_state_dict_and_infos(state_dict, path: str, infos=None): + rank = int(os.environ.get('RANK', 0)) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format( + rank, path)) + torch.save(state_dict, path) + info_path = re.sub('.pt$', '.yaml', path) + if infos is None: + infos = {} + infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(infos) + fout.write(data) + + def save_checkpoint(model: torch.nn.Module, path: str, infos=None): ''' Args: infos (dict or None): any info you want to save. 
''' - logging.info('Checkpoint: save to checkpoint %s' % path) if isinstance(model, torch.nn.DataParallel): state_dict = model.module.state_dict() elif isinstance(model, torch.nn.parallel.DistributedDataParallel): state_dict = model.module.state_dict() else: state_dict = model.state_dict() - torch.save(state_dict, path) - info_path = re.sub('.pt$', '.yaml', path) - if infos is None: - infos = {} - infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') - with open(info_path, 'w') as fout: - data = yaml.dump(infos) - fout.write(data) + save_state_dict_and_infos(state_dict, path, infos) def filter_modules(model_state_dict, modules): + rank = int(os.environ.get('RANK', 0)) new_mods = [] incorrect_mods = [] mods_model = model_state_dict.keys() @@ -71,7 +81,7 @@ def filter_modules(model_state_dict, modules): new_mods += [mod] else: incorrect_mods += [mod] - if incorrect_mods: + if incorrect_mods and rank == 0: logging.warning( "module(s) %s don't match or (partially match) " "available modules in model.", diff --git a/wenet/utils/common.py b/wenet/utils/common.py index 282d73179f..cdc7610f13 100644 --- a/wenet/utils/common.py +++ b/wenet/utils/common.py @@ -321,6 +321,19 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return mask +def get_nested_attribute(obj, attr_path): + if isinstance(obj, torch.nn.parallel.DistributedDataParallel): + obj = obj.module + attributes = attr_path.split('.') + for attr in attributes: + obj = getattr(obj, attr) + return obj + + +def lrs_to_str(lrs: List): + return " ".join(["{:.4e}".format(lr) for lr in lrs]) + + class StepTimer: """Utility class for measuring steps/second.""" @@ -338,3 +351,9 @@ def steps_per_second(self, cur_step, restart=True): self.start() self.last_iteration = float(cur_step) return value + + +def tensor_to_scalar(x): + if torch.is_tensor(x): + return x.item() + return x diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index bd9db93382..45f2739e2d 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -31,7 +31,7 @@ class Executor: def __init__(self, global_step: int = 0): - self.step = global_step + self.step = global_step + 1 self.train_step_timer = None self.cv_step_timer = None @@ -68,8 +68,9 @@ def train(self, model, optimizer, scheduler, train_data_loader, # Disable gradient synchronizations across DDP processes. # Within this context, gradients will be accumulated on module # variables, which will later be synchronized. - if info_dict.get("train_engine", "torch_ddp") == "torch_ddp" and \ - (batch_idx + 1) % info_dict["accum_grad"] != 0: + if info_dict.get("train_engine", "torch_ddp") in [ + "torch_ddp", "torch_fsdp" + ] and (batch_idx + 1) % info_dict["accum_grad"] != 0: context = model.no_sync # Used for single gpu training and DDP gradient synchronization # processes. 
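The executor change above extends the gradient-accumulation guard so both `torch_ddp` and `torch_fsdp` skip inter-rank synchronization on micro-batches that do not complete an accumulation window. A compact sketch of that pattern follows; `train_one_epoch` is hypothetical, it assumes the wrapped model call returns a scalar loss, and it omits the AMP, clipping, logging, and checkpointing that the real `Executor` handles:

```python
import contextlib


def train_one_epoch(model, optimizer, data_loader, accum_grad: int = 4,
                    train_engine: str = "torch_ddp"):
    """Accumulate gradients over accum_grad micro-batches; only the last
    micro-batch in each window triggers the cross-rank all-reduce."""
    for batch_idx, batch in enumerate(data_loader):
        last_in_window = (batch_idx + 1) % accum_grad == 0
        if train_engine in ["torch_ddp", "torch_fsdp"] and not last_in_window:
            # Both DDP and FSDP wrappers expose no_sync() as a context manager.
            context = model.no_sync
        else:
            context = contextlib.nullcontext
        with context():
            loss = model(batch) / accum_grad  # scale so summed grads match
            loss.backward()
        if last_in_window:
            optimizer.step()
            optimizer.zero_grad()
```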
@@ -84,9 +85,15 @@ def train(self, model, optimizer, scheduler, train_data_loader, info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + # write training: tensorboard && log + log_per_step(writer, info_dict, timer=self.train_step_timer) save_interval = info_dict.get('save_interval', sys.maxsize) - if self.step % save_interval == 0 and self.step != 0 \ - and (batch_idx + 1) % info_dict["accum_grad"] == 0: + if (self.step + + 1) % save_interval == 0 and self.step != 0 and ( + batch_idx + 1) % info_dict["accum_grad"] == 0: + import torch.distributed as dist + # Ensure all ranks start CV at the same time in step mode + dist.barrier() loss_dict = self.cv(model, cv_data_loader, configs) model.train() info_dict.update({ @@ -96,15 +103,17 @@ def train(self, model, optimizer, scheduler, train_data_loader, loss_dict, "save_time": datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), - "lr": - optimizer.param_groups[0]['lr'] + "lrs": + [group['lr'] for group in optimizer.param_groups] }) save_model(model, info_dict) - log_per_step(writer, info_dict, timer=self.train_step_timer) + # write final cv: tensorboard + log_per_step(writer, info_dict) + # Ensure all ranks start Train at the same time in step mode + dist.barrier() self.step += 1 if (batch_idx + 1) % info_dict["accum_grad"] == 0 else 0 - def cv(self, model, cv_data_loader, configs): ''' Cross validation on ''' @@ -138,7 +147,7 @@ def cv(self, model, cv_data_loader, configs): loss_value = loss_value.item() loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \ loss_value * num_utts - + # write cv: log log_per_step(writer=None, info_dict=info_dict, timer=self.cv_step_timer) diff --git a/wenet/utils/fsdp_utils.py b/wenet/utils/fsdp_utils.py new file mode 100644 index 0000000000..77ca195953 --- /dev/null +++ b/wenet/utils/fsdp_utils.py @@ -0,0 +1,118 @@ +from functools import partial +import os +from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP, + FullStateDictConfig, StateDictType) + +from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy, + transformer_auto_wrap_policy) +from wenet.LLM.decoder import DecoderOnly +from wenet.branchformer.encoder_layer import BranchformerEncoderLayer +from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer +from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer +from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer +from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer +from wenet.transformer.encoder_layer import (ConformerEncoderLayer, + TransformerEncoderLayer) +from wenet.transformer.decoder_layer import DecoderLayer +from wenet.utils.checkpoint import save_state_dict_and_infos +from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES + +WENET_ENCODER_LAYERS_CLASSES = { + 'transformer_encoder_layer': TransformerEncoderLayer, + 'conformer_encoder_layer': ConformerEncoderLayer, + 'paraformer_encoder_layer': AliParaformerEncoderLayer, + 'squeezeformer_encoder_layer': SqueezeformerEncoderLayer, + 'ebranchformer_encoder_layer': EBranchformerEncoderLayer, + 'efficient_conformer_encoder_layer': StrideConformerEncoderLayer, + 'branchformer_encoder_layer': BranchformerEncoderLayer, +} + +WENET_DECODER_LAYERS_CLASSES = { + 'transformer_decoder_layer': DecoderLayer, + 'paraformer_decoder_layer': SanmDecoderLayer, + # TODO(Mddct): + # 1 wrap transducer's predictor and joint + # 2 wrap paraformer's cif and ignore lstm +} + + +def 
wenet_fsdp_wrap_policy(mode):
+ # different wrap methods
+ # please refer: https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
+ assert mode in ['no_shard', 'model', 'zero2', 'zero3']
+ if mode == 'no_shard':
+ return None
+ else:
+ # TODO(Mddct): Support user customization
+ # see more wrap methods:
+ # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
+ if mode == 'model':
+ enc_dec_wrap_policy = partial(
+ lambda_auto_wrap_policy,
+ lambda_fn=lambda module: isinstance(
+ module,
+ tuple(WENET_ENCODER_CLASSES.values()) + tuple(
+ WENET_DECODER_CLASSES.values())))
+ return enc_dec_wrap_policy
+ else:
+ to_wrap_class = set()
+ to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
+ to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
+ layers_wrap_policy = partial(transformer_auto_wrap_policy,
+ transformer_layer_cls=to_wrap_class)
+ return layers_wrap_policy
+
+
+fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
+ rank0_only=True)
+
+
+def fsdp_save_model(model, save_model_path, info_dict):
+ # TODO(Mddct): When the model is large, saving it can take a long time.
+ # Ideally the shards would be saved asynchronously, but this is good
+ # enough for now. This will be revisited when LLM support is added.
+
+ rank = int(os.environ.get('RANK', 0))
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
+ fullstate_save_policy):
+ state_dict = model.state_dict()
+ if rank == 0:
+ save_state_dict_and_infos(state_dict, save_model_path, info_dict)
+
+
+def check_gradient_checkpoint(model):
+ ckpt_layer_types = []
+ if hasattr(model, 'encoder') and hasattr(model.encoder,
+ 'gradient_checkpointing'):
+ if model.encoder.gradient_checkpointing:
+ model.encoder.gradient_checkpointing = False
+ ckpt_layer_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
+ if hasattr(model, 'decoder') and hasattr(model.decoder,
+ 'gradient_checkpointing'):
+ if model.decoder.gradient_checkpointing:
+ model.decoder.gradient_checkpointing = False
+ ckpt_layer_types += list(WENET_DECODER_LAYERS_CLASSES.values())
+ if isinstance(model.decoder, DecoderOnly):
+ ckpt_layer_types += [DecoderOnly]
+ return tuple(ckpt_layer_types)
+
+
+def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
+ # NOTE(Mddct): torch.utils.checkpoint is currently incompatible with
+ # wenet's model mode.
Using this writing method, Please refer to + # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21 # noqa + if len(ckpt_layer_types) == 0: + return + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, + ) + non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types)) diff --git a/wenet/utils/init_dataset.py b/wenet/utils/init_dataset.py new file mode 100644 index 0000000000..60a9c066dd --- /dev/null +++ b/wenet/utils/init_dataset.py @@ -0,0 +1,19 @@ +from typing import Optional +from wenet.dataset.dataset import Dataset as ASRDatast +from wenet.dataset.llm_dataset import Dataset as LLMDataset +from wenet.text.base_tokenizer import BaseTokenizer + + +def init_dataset(data_type, + data_list_file, + conf, + tokenizer: Optional[BaseTokenizer] = None, + partition=True, + dataset_type: str = 'asr'): + assert dataset_type in ['asr', 'llm'] + if dataset_type == 'asr': + return ASRDatast(data_type, data_list_file, tokenizer, conf, partition) + else: + assert tokenizer is not None + return LLMDataset(data_type, data_list_file, tokenizer, conf, + partition) diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 1255fb5566..a771dfe5e7 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import torch +from wenet.finetune.lora.utils import mark_only_lora_as_trainable from wenet.k2.model import K2Model from wenet.paraformer.cif import Cif from wenet.paraformer.layers import SanmDecoder, SanmEncoder from wenet.paraformer.paraformer import Paraformer, Predictor +from wenet.LLM.causal_model import CausalLM +from wenet.LLM.decoder import DecoderOnly from wenet.transducer.joint import TransducerJoint from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, RNNPredictor) @@ -36,6 +40,8 @@ from wenet.whisper.whisper import Whisper from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules +from wenet.finetune.lora.encoder import (LoRATransformerEncoder, + LoRAConformerEncoder) WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, @@ -47,6 +53,8 @@ "dual_transformer": DualTransformerEncoder, "dual_conformer": DualConformerEncoder, 'sanm_encoder': SanmEncoder, + "lora_transformer": LoRATransformerEncoder, + "lora_conformer": LoRAConformerEncoder, } WENET_DECODER_CLASSES = { @@ -78,11 +86,11 @@ "k2_model": K2Model, "transducer": Transducer, 'paraformer': Paraformer, + 'causal_llm': CausalLM, } -def init_model(args, configs): - +def init_speech_model(args, configs): # TODO(xcsong): Forcefully read the 'cmvn' attribute. 
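# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the fsdp_utils helpers
# added above are meant to be combined. wrap_cuda_model() later in this patch
# does the full version (mixed precision, CPU offload, sync_module_states);
# this minimal variant assumes torch.distributed is already initialized and a
# CUDA device is set. The function name is hypothetical.
def _fsdp_wrap_sketch(model, mode='zero3'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from wenet.utils.fsdp_utils import (wenet_fsdp_wrap_policy,
                                        check_gradient_checkpoint,
                                        apply_fsdp_checkpointing)
    # Remember which layer classes had gradient checkpointing enabled and
    # switch off the built-in torch.utils.checkpoint variant.
    layer_types = check_gradient_checkpoint(model)
    # 'model' wraps whole encoders/decoders; 'zero2'/'zero3' wrap per layer.
    wrap_policy = wenet_fsdp_wrap_policy(mode=mode)
    model = FSDP(model,
                 auto_wrap_policy=wrap_policy,
                 use_orig_params=True)
    # Re-apply activation checkpointing via FSDP's non-reentrant wrapper.
    apply_fsdp_checkpointing(model, layer_types)
    return model
# ---------------------------------------------------------------------------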
if configs.get('cmvn', None) == 'global_cmvn': mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], @@ -100,6 +108,9 @@ def init_model(args, configs): decoder_type = configs.get('decoder', 'bitransformer') ctc_type = configs.get('ctc', 'ctc') + if hasattr(args, 'use_lora') and args.use_lora: + encoder_type = "lora_" + encoder_type + encoder = WENET_ENCODER_CLASSES[encoder_type]( input_dim, global_cmvn=global_cmvn, @@ -159,6 +170,32 @@ def init_model(args, configs): special_tokens=configs.get('tokenizer_conf', {}).get('special_tokens', None), **configs['model_conf']) + return model, configs + + +def init_causal_llm(configs): + vocab_size = configs['output_dim'] + assert configs['decoder'] == 'decoder_only' + assert configs['model'] == 'causal_lm' + decoder_only = DecoderOnly(**configs['decoder_conf']) + + model = CausalLM( + vocab_size, + decoder_only, + **configs['model_conf'], + special_tokens=configs.get('tokenizer_conf', + {}).get('special_tokens', None), + ) + return model, configs + + +def init_model(args, configs): + + model_type = configs.get('model', 'asr_model') + if model_type == 'causal_lm': + model, configs = init_causal_llm(configs) + else: + model, configs = init_speech_model(args, configs) # If specify checkpoint, load some info from checkpoint if hasattr(args, 'checkpoint') and args.checkpoint is not None: @@ -168,10 +205,13 @@ def init_model(args, configs): else: infos = {} configs["init_infos"] = infos + print(configs) + # Trye to tie some weights + if hasattr(model, 'tie_or_clone_weights'): + model.tie_or_clone_weights(args.jit) - # Tie emb.weight to decoder.output_layer.weight - if model.decoder.tie_word_embedding: - model.decoder.tie_or_clone_weights(jit_mode=args.jit) + if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora: + mark_only_lora_as_trainable(model, bias='lora_only') return model, configs diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py index c0c2ce7d77..9f42f058a3 100644 --- a/wenet/utils/init_tokenizer.py +++ b/wenet/utils/init_tokenizer.py @@ -18,6 +18,7 @@ from wenet.text.base_tokenizer import BaseTokenizer from wenet.text.bpe_tokenizer import BpeTokenizer from wenet.text.char_tokenizer import CharTokenizer +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer from wenet.text.paraformer_tokenizer import ParaformerTokenizer from wenet.text.whisper_tokenizer import WhisperTokenizer @@ -47,6 +48,9 @@ def init_tokenizer(configs) -> BaseTokenizer: tokenizer = ParaformerTokenizer( symbol_table=configs['tokenizer_conf']['symbol_table_path'], seg_dict=configs['tokenizer_conf']['seg_dict_path']) + elif tokenizer_type == 'huggingface': + tokenizer = HuggingFaceTokenizer( + model=configs['tokenizer_conf']['model']) else: raise NotImplementedError logging.info("use {} tokenizer".format(configs["tokenizer"])) diff --git a/wenet/utils/rope_utils.py b/wenet/utils/rope_utils.py index e80bf9ace7..54f13c47b8 100644 --- a/wenet/utils/rope_utils.py +++ b/wenet/utils/rope_utils.py @@ -31,3 +31,9 @@ def llama_apply_rotary_emb(x: torch.Tensor, x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) return x_out.type_as(x) + + +WENET_APPLY_ROTARY_EMB = { + 'google': google_apply_rotary_emb, + 'llama': llama_apply_rotary_emb, +} diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index 6a78bb6c7e..170e4fd1db 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -15,7 +15,7 @@ # Modified from 
ESPnet(https://github.com/espnet/espnet) # NeMo(https://github.com/NVIDIA/NeMo) -from typing import Union +from typing import List, Union import math import warnings @@ -43,11 +43,10 @@ class WarmupLR(_LRScheduler): def __init__( self, optimizer: torch.optim.Optimizer, - warmup_steps: Union[int, float] = 25000, + warmup_steps: Union[int, float, List[Union[int, float]]] = 25000, last_epoch: int = -1, ): self.warmup_steps = warmup_steps - # __init__() must be invoked before setting field # because step() is also invoked in __init__() super().__init__(optimizer, last_epoch) @@ -57,14 +56,21 @@ def __repr__(self): def get_lr(self): step_num = self.last_epoch + 1 - if self.warmup_steps == 0: - return [lr * step_num**-0.5 for lr in self.base_lrs] - else: - return [ - lr * self.warmup_steps**0.5 * - min(step_num**-0.5, step_num * self.warmup_steps**-1.5) - for lr in self.base_lrs - ] + warmup_steps = self.warmup_steps + if not isinstance(warmup_steps, List): + warmup_steps = [self.warmup_steps] * len(self.base_lrs) + + def initlr_fn(lr): + return lr * step_num**-0.5 + + def warmuplr_fn(lr, warmup_step): + return lr * warmup_step**0.5 * min(step_num**-0.5, + step_num * warmup_step**-1.5) + + return [ + initlr_fn(lr) if warmup_steps[i] == 0 else warmuplr_fn( + lr, warmup_steps[i]) for (i, lr) in enumerate(self.base_lrs) + ] def set_step(self, step: int): self.last_epoch = step diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index d3d06f57ac..f8f9fa095c 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext import copy -from typing import Optional +from typing import List, Optional + import deepspeed import json import logging @@ -28,17 +30,24 @@ from tensorboardX import SummaryWriter from torch.utils.data import DataLoader from torch.nn.utils import clip_grad_norm_ +from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP, + CPUOffload, MixedPrecision, + sharded_grad_scaler, ShardingStrategy) from deepspeed.runtime.zero.stage_1_and_2 import ( estimate_zero2_model_states_mem_needs_all_live) from deepspeed.runtime.zero.stage3 import ( estimate_zero3_model_states_mem_needs_all_live) from deepspeed.utils.zero_to_fp32 import ( convert_zero_checkpoint_to_fp32_state_dict) -from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint -from wenet.utils.common import StepTimer +from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str, + tensor_to_scalar) +from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model, + apply_fsdp_checkpointing, + wenet_fsdp_wrap_policy) from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing from wenet.utils.ctc_utils import get_blank_id +from wenet.utils.init_dataset import init_dataset def add_model_args(parser): @@ -105,6 +114,34 @@ def add_dataset_args(parser): return parser +def add_lora_args(parser): + parser.add_argument("--use_lora", + default=False, + type=bool, + help="whether use the lora finetune.") + parser.add_argument("--only_optimize_lora", + default=False, + type=bool, + help="freeze all other paramters and only optimize \ + LoRA-related prameters.") + parser.add_argument("--lora_list", + default=['o', 'q', 'k', 'v'], + help="lora module list.") + parser.add_argument("--lora_rank", + default=8, + type=int, + help="lora rank num.") + 
parser.add_argument("--lora_alpha", + default=8, + type=int, + help="lora scale param, scale=lora_alpha/lora_rank.") + parser.add_argument("--lora_dropout", + default=0, + type=float, + help="lora dropout param.") + return parser + + def add_ddp_args(parser): parser.add_argument('--ddp.dist_backend', dest='dist_backend', @@ -142,13 +179,48 @@ def add_deepspeed_args(parser): return parser +def add_fsdp_args(parser): + parser.add_argument( + '--dtype', + default='fp32', + choices=['fp32', 'fp16', 'bf16'], + help='when amp is used, dtype is automatically set to fp16.\ + this arg has no effect when deepspeed is enabled.') + parser.add_argument( + '--fsdp_cpu_offload', + default=False, + type=bool, + help='whether to offload parameters to CPU', + ) + parser.add_argument( + '--fsdp_sync_module_states', + type=bool, + default=True, + help='\ + each FSDP module will broadcast module parameters and buffers from \ + rank 0 to ensure that they are replicated across ranks', + ) + parser.add_argument( + '--fsdp_sharding_strategy', + default='zero2', + # TODO(Mddct): pipeline and model parallel (3-D parallelism) + choices=['no_shard', 'model', 'zero2', 'zero3'], + help='Sharding strategy for FSDP. Choose from the following options:\n' + ' - "no_shard": Equivalent to DistributedDataParallel (DDP).\n' + ' - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n' + ' - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n' + ' - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n' + 'For more information, refer to the FSDP API documentation.') + return parser + + def init_distributed(args): world_size = int(os.environ.get('WORLD_SIZE', 1)) local_rank = int(os.environ.get('LOCAL_RANK', 0)) rank = int(os.environ.get('RANK', 0)) logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + ', rank {}, world_size {}'.format(rank, world_size)) - if args.train_engine == "torch_ddp": + if args.train_engine in ["torch_ddp", "torch_fsdp"]: torch.cuda.set_device(local_rank) dist.init_process_group(args.dist_backend) elif args.train_engine == "deepspeed": @@ -159,11 +231,12 @@ def init_distributed(args): def check_modify_and_save_config(args, configs, symbol_table): - if args.train_engine == "torch_ddp": + if args.train_engine in ["torch_ddp", "torch_fsdp"]: if args.use_amp: configs["dtype"] = "fp16" + args.dtype = 'fp16' else: - configs["dtype"] = "fp32" + configs["dtype"] = args.dtype elif args.train_engine == "deepspeed": # NOTE(xcsong): DeepSpeed does not support uneven data. 
When using custom # dataset, we need to manually ensure that the data is evenly distributed @@ -203,20 +276,28 @@ def check_modify_and_save_config(args, configs, symbol_table): assert ds_configs["gradient_clipping"] == configs['grad_clip'] assert ds_configs["steps_per_print"] == configs['log_interval'] - if 'input_dim' not in configs: - if 'fbank_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] - elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['log_mel_spectrogram_conf'][ - 'num_mel_bins'] + if args.use_lora: + configs['encoder_conf']['lora_list'] = args.lora_list + configs['encoder_conf']['lora_rank'] = args.lora_rank + configs['encoder_conf']['lora_alpha'] = args.lora_alpha + configs['encoder_conf']['lora_dropout'] = args.lora_dropout + if configs["model"] == 'asr': + if 'input_dim' not in configs: + if 'fbank_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf']['fbank_conf'][ + 'num_mel_bins'] + elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf'][ + 'log_mel_spectrogram_conf']['num_mel_bins'] + else: + input_dim = configs['dataset_conf']['mfcc_conf'][ + 'num_mel_bins'] else: - input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] - else: - input_dim = configs['input_dim'] + input_dim = configs['input_dim'] - configs, _ = get_blank_id(configs, symbol_table) + configs, _ = get_blank_id(configs, symbol_table) - configs['input_dim'] = input_dim + configs['input_dim'] = input_dim configs['output_dim'] = configs['vocab_size'] configs['train_engine'] = args.train_engine @@ -256,13 +337,18 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777): cv_conf['list_shuffle'] = False configs['vocab_size'] = tokenizer.vocab_size() - train_dataset = Dataset(args.data_type, args.train_data, tokenizer, - train_conf, True) - cv_dataset = Dataset(args.data_type, - args.cv_data, - tokenizer, - cv_conf, - partition=False) + train_dataset = init_dataset(args.data_type, + args.train_data, + train_conf, + tokenizer, + True, + dataset_type=configs['dataset']) + cv_dataset = init_dataset(args.data_type, + args.cv_data, + cv_conf, + tokenizer, + partition=False, + dataset_type=configs['dataset']) # NOTE(xcsong): Why we prefer persistent_workers=True ? # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110 @@ -283,25 +369,19 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777): return train_dataset, cv_dataset, train_data_loader, cv_data_loader -def wrap_cuda_model(args, model): +def wrap_cuda_model(args, model, configs=None): local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) world_size = int(os.environ.get('WORLD_SIZE', 1)) if hasattr(model, 'encoder'): grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False) else: grad_ckpt = False - # TODO(xcsong): could one GPU use ddp? 
and int(os.environ.get('WORLD_SIZE', 1)) > 1 if args.train_engine == "torch_ddp": # native pytorch ddp assert (torch.cuda.is_available()) model.cuda() model = torch.nn.parallel.DistributedDataParallel( model, find_unused_parameters=not grad_ckpt) device = torch.device("cuda") - if args.fp16_grad_sync: - from torch.distributed.algorithms.ddp_comm_hooks import ( - default as comm_hooks, ) - model.register_comm_hook(state=None, - hook=comm_hooks.fp16_compress_hook) elif args.train_engine == "deepspeed": # deepspeed # NOTE(xcsong): look in detail how the memory estimator API works: # https://deepspeed.readthedocs.io/en/latest/memory.html#discussion @@ -318,17 +398,87 @@ def wrap_cuda_model(args, model): num_nodes=world_size // local_world_size) device = None # Init device later pass # Init DeepSpeed later + elif args.train_engine == 'torch_fsdp': + assert configs is not None + mixed_precision_dtype = { + 'fp32': torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + }[configs['dtype']] + + sharding_strategy = { + 'model': ShardingStrategy.SHARD_GRAD_OP, + 'zero2': ShardingStrategy.SHARD_GRAD_OP, + 'zero3': ShardingStrategy.FULL_SHARD, + 'no_shard': ShardingStrategy.NO_SHARD, + }[args.fsdp_sharding_strategy] + wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy) + layer_types = check_gradient_checkpoint(model) + model = FSDP( + model, + auto_wrap_policy=wrap_policy, + cpu_offload=CPUOffload(offload_params=True) + if args.fsdp_cpu_offload is True else None, + mixed_precision=MixedPrecision( + param_dtype=mixed_precision_dtype, + reduce_dtype=mixed_precision_dtype, + buffer_dtype=mixed_precision_dtype, + ), + sharding_strategy=sharding_strategy, + limit_all_gathers=True, + use_orig_params=True, + sync_module_states=args.fsdp_sync_module_states, + # init_distributed is called (torch.cuda.set_device), + # we should set device_id, see FSDP api + device_id=torch.cuda.current_device(), + ) + apply_fsdp_checkpointing(model, layer_types) + device = torch.device("cuda") else: logging.error("not supported engine: {}".format(args.train_engine)) + if args.train_engine in ["torch_fsdp", "torch_ddp"]: + if args.fp16_grad_sync: + from torch.distributed.algorithms.ddp_comm_hooks import ( + default as comm_hooks, ) + model.register_comm_hook(state=None, + hook=comm_hooks.fp16_compress_hook) return model, device def init_optimizer_and_scheduler(args, configs, model): + groups = [] + lr = configs['optim_conf'].get('lr') + if isinstance(lr, List): + assert configs['scheduler'] == 'warmuplr' + modules_m = configs['optim_conf']['modules'] + assert isinstance(modules_m, List) + assert len(modules_m) + 1 == len(lr) + special_param_ids = set() + rest_params = [] + for (i, m_str) in enumerate(modules_m): + sub_module = get_nested_attribute(model, m_str) + subs_params = [] + for _, sub_params in sub_module.named_parameters(): + subs_params.append(sub_params) + special_param_ids.add(id(sub_params)) + groups.append({'params': subs_params, 'lr': lr[i]}) + # other model's parameters + for _, param in model.named_parameters(): + if id(param) not in special_param_ids: + rest_params.append(param) + groups.append({'params': rest_params, 'lr': lr[-1]}) + + params = groups if len(groups) > 0 else model.parameters() + optim_conf = copy.deepcopy(configs['optim_conf']) + if 'modules' in optim_conf: + del optim_conf['modules'] + if isinstance(lr, List): + optim_conf['lr'] = lr[-1] if configs['optim'] == 'adam': - optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) + optimizer = 
optim.Adam(params, **optim_conf) elif configs['optim'] == 'adamw': - optimizer = optim.AdamW(model.parameters(), **configs['optim_conf']) + optimizer = optim.AdamW(params, **optim_conf) else: raise ValueError("unknown optimizer: " + configs['optim']) @@ -396,10 +546,23 @@ def init_summarywriter(args): return writer +def init_scaler(args): + scaler = None + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + elif args.train_engine == 'torch_fsdp': + # why bf16 don't need scaler: + # https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596 + if args.dtype in ['fp16']: + scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True) + return scaler + + def save_model(model, info_dict): rank = int(os.environ.get('RANK', 0)) tag = info_dict["tag"] model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(tag)) # save ckpt if info_dict["train_engine"] == "deepspeed": # NOTE(xcsong): All ranks should call this API, but only rank 0 @@ -411,13 +574,14 @@ def save_model(model, info_dict): client_state=info_dict) if info_dict["save_states"] == "model_only" and rank == 0: convert_zero_checkpoint_to_fp32_state_dict(model_dir, - "{}/{}.pt".format( - model_dir, tag), + save_model_path, tag=tag) os.system("rm -rf {}/{}".format(model_dir, tag)) + + elif info_dict['train_engine'] == "torch_fsdp": + fsdp_save_model(model, save_model_path, info_dict) elif rank == 0: # NOTE(xcsong): For torch_ddp, only rank-0 should call this. - save_model_path = os.path.join(model_dir, '{}.pt'.format(tag)) save_checkpoint(model, save_model_path, info_dict) # save yaml if rank == 0: @@ -468,21 +632,24 @@ def batch_forward(model, batch, scaler, info_dict): else: # fp32 dtype = None - if train_engine == "deepspeed": - # deepspeed - with torch.cuda.amp.autocast(enabled=dtype is not None, - dtype=dtype, - cache_enabled=False): - loss_dict = model(batch, device) - else: - # torch_ddp - # autocast context - # The more details about amp can be found in - # https://pytorch.org/docs/stable/notes/amp_examples.html - with torch.cuda.amp.autocast(scaler is not None): - loss_dict = model(batch, device) - info_dict['loss_dict'] = loss_dict + # autocast context + # The more details about amp can be found in + # https://pytorch.org/docs/stable/notes/amp_examples.html + autocast = { + "deepspeed": + torch.cuda.amp.autocast(enabled=dtype is not None, + dtype=dtype, + cache_enabled=False), + "torch_ddp": + torch.cuda.amp.autocast(enabled=scaler is not None), + "torch_fsdp": + torch.cuda.amp.autocast(enabled=True, dtype=dtype) + if dtype is not None else nullcontext() + }[train_engine] + with autocast: + loss_dict = model(batch, device) + info_dict['loss_dict'] = loss_dict return info_dict @@ -499,16 +666,21 @@ def batch_backward(model, scaler, info_dict): # `scale_loss_wrt_accum_grad + loss.backward()` # ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api scaled_loss = model.backward(loss) - elif train_engine == "torch_ddp": + else: + assert train_engine in ["torch_ddp", "torch_fsdp"] scaled_loss = loss / accum_grad - if use_amp: + if scaler is not None: + # fp16 (amp and fsdp) scaler.scale(scaled_loss).backward() else: + # float32 (ddp and fsdp) + # bf16 (fsdp) scaled_loss.backward() + info_dict['loss_dict']['loss'] = scaled_loss for loss_name, loss_value in info_dict['loss_dict'].items(): if loss_value is not None: - info_dict['loss_dict'][loss_name] = loss_value.item() + info_dict['loss_dict'][loss_name] = tensor_to_scalar(loss_value) return info_dict @@ -540,9 +712,14 
@@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): grad_norm = model.get_global_grad_norm() elif (batch_idx + 1) % accum_grad == 0: # Use mixed precision training - if use_amp: + # fp16 (ddp fsdp) + if scaler is not None: scaler.unscale_(optimizer) - grad_norm = clip_grad_norm_(model.parameters(), clip) + if train_engine == "torch_ddp": + grad_norm = clip_grad_norm_(model.parameters(), clip) + else: + # fsdp + grad_norm = model.clip_grad_norm_(clip) # Must invoke scaler.update() if unscale_() is used in # the iteration to avoid the following error: # RuntimeError: unscale_() has already been called @@ -553,15 +730,17 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): scaler.step(optimizer) scaler.update() else: - grad_norm = clip_grad_norm_(model.parameters(), clip) + if train_engine == "torch_ddp": + grad_norm = clip_grad_norm_(model.parameters(), clip) + else: + grad_norm = model.clip_grad_norm_(clip) if torch.isfinite(grad_norm): optimizer.step() optimizer.zero_grad() scheduler.step() - grad_norm = grad_norm.item() - info_dict["lr"] = optimizer.param_groups[0]['lr'] - info_dict["grad_norm"] = grad_norm + info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups] + info_dict["grad_norm"] = tensor_to_scalar(grad_norm) return info_dict @@ -575,27 +754,40 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None): train_engine = info_dict.get("train_engine", "torch_ddp") accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1 log_interval = info_dict.get('log_interval', 10) - lr = info_dict.get("lr", 0.0) + lrs = info_dict.get("lrs", [0.0]) is_gradient_accumulation_boundary = info_dict.get( "is_gradient_accumulation_boundary", False) rank = int(os.environ.get('RANK', 0)) - + # TRAIN Tensorboard if tag == "TRAIN" and rank == 0 and writer is not None: - if (train_engine == "deepspeed" and is_gradient_accumulation_boundary) or \ - (train_engine == "torch_ddp" and (batch_idx + 1) % accum_grad == 0): + if (train_engine == "deepspeed" and is_gradient_accumulation_boundary + ) or (train_engine in ["torch_ddp", "torch_fsdp"] and + (batch_idx + 1) % accum_grad == 0): writer.add_scalar('train/train_loss', - loss_dict['loss'] * accum_grad, step + 1) - writer.add_scalar('train/grad_norm', info_dict['grad_norm'], - step + 1) + tensor_to_scalar(loss_dict['loss']) * accum_grad, + step) + writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step) for name, value in loss_dict.items(): if name != 'loss' and value is not None: - writer.add_scalar('train/{}'.format(name), value, step + 1) + writer.add_scalar('train/{}'.format(name), + tensor_to_scalar(value), step) + # lr + for i, lr in enumerate(lrs): + writer.add_scalar('train/lr_{}'.format(i), lr, step) + # CV Tensorboard elif "step_" in tag and rank == 0 and writer is not None: - writer.add_scalar('global_step/lr', lr, step + 1) for name, value in loss_dict.items(): - writer.add_scalar('global_step/{}'.format(name), value, step + 1) - + writer.add_scalar('cv/{}'.format(name), tensor_to_scalar(value), + step) + logging.info( + 'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format( + epoch, step + 1, lrs_to_str(lrs), + tensor_to_scalar(loss_dict["loss"]), rank, + tensor_to_scalar(loss_dict["acc"]))) + return + + # TRAIN & CV, Shell log (stdout) if (batch_idx + 1) % log_interval == 0: log_str = '{} | '.format(tag) if timer is not None: @@ -605,25 +797,35 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None): steps_per_second = 
timer.steps_per_second(timer_step) log_str += 'steps/sec {:.1f}| '.format(steps_per_second) log_str += 'Batch {}/{} loss {:.6f} '.format( - epoch, - batch_idx + 1 if 'save_interval' not in info_dict else step + 1, - loss_dict['loss'] * accum_grad) + epoch, batch_idx + 1 if 'save_interval' not in info_dict else + (step + 1) * accum_grad, + tensor_to_scalar(loss_dict['loss']) * accum_grad) for name, value in loss_dict.items(): if name != 'loss' and value is not None: - log_str += '{} {:.6f} '.format(name, value) + log_str += '{} {:.6f} '.format(name, tensor_to_scalar(value)) if tag == "TRAIN": - log_str += 'lr {:.8f} grad_norm {:.6f} rank {}'.format( - lr, info_dict['grad_norm'], rank) + log_str += 'lr {} grad_norm {:.6f} rank {}'.format( + lrs_to_str(lrs), info_dict['grad_norm'], rank) logging.debug(log_str) def log_per_epoch(writer, info_dict): epoch = info_dict["epoch"] loss_dict = info_dict["loss_dict"] + lrs = info_dict['lrs'] + rank = int(os.environ.get('RANK', 0)) + step = info_dict["step"] + logging.info( + 'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format( + epoch, step, lrs_to_str(lrs), tensor_to_scalar(loss_dict["loss"]), + rank, tensor_to_scalar(loss_dict["acc"]))) + if int(os.environ.get('RANK', 0)) == 0: - writer.add_scalar('epoch/lr', info_dict["lr"], epoch) + for i, lr in enumerate(info_dict["lrs"]): + writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch) for name, value in loss_dict.items(): - writer.add_scalar('epoch/{}'.format(name), value, epoch) + writer.add_scalar('epoch/{}'.format(name), tensor_to_scalar(value), + epoch) def freeze_modules(model, args): diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 4c280e19f2..fff096db19 100755 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -166,6 +166,8 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic' configs['dataset_conf']['batch_conf']['batch_size'] = 26 configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000 + configs['dataset_conf']['language_conf'] = {} + configs['dataset_conf']['language_conf']['limited_langs'] = ['zh'] configs['grad_clip'] = 5 configs['accum_grad'] = 4
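The optimizer and scheduler changes above let 'optim_conf' carry a list of learning rates together with a 'modules' list (plus one extra lr entry for all remaining parameters), and let WarmupLR take one warmup_steps value per parameter group. A minimal sketch of the resulting grouping, using plain PyTorch; the module names 'encoder' and 'decoder' are hypothetical stand-ins for whatever paths appear in optim_conf['modules']:

import torch

# Toy stand-in for a wenet model; names are hypothetical.
model = torch.nn.ModuleDict({
    'encoder': torch.nn.Linear(4, 4),
    'decoder': torch.nn.Linear(4, 4),
})

# Mirrors init_optimizer_and_scheduler: one group per entry in
# optim_conf['modules'] (here just 'encoder'), plus a final group holding
# every remaining parameter at the last lr in the list.
groups = [
    {'params': list(model['encoder'].parameters()), 'lr': 5.0e-5},
    {'params': list(model['decoder'].parameters()), 'lr': 1.0e-3},
]
optimizer = torch.optim.AdamW(groups, lr=1.0e-3)

# WarmupLR(optimizer, warmup_steps=[5000, 25000]) would then warm up each
# group on its own schedule, matching base_lrs index by index.
for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'])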