From f856535a3b3fdc5b77ef534257f671be9f1109bc Mon Sep 17 00:00:00 2001 From: zhutong Date: Fri, 3 Nov 2023 18:51:36 +0800 Subject: [PATCH 1/4] update consumed_tokens api --- .vscode/launch.json | 2 +- smoe/callbacks/tensorboard.py | 6 ++ smoe/data/streaming.py | 145 +++++++++++++++++++++------- smoe/trainer/llama_lr_scheduling.py | 25 +++-- smoe/utils/io.py | 5 + smoe/utils/tokenize.py | 8 +- smoe/utils/vars.py | 1 + tests/data/test_streaming.py | 22 +++-- 8 files changed, 160 insertions(+), 54 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 1f0445f..c70d9a0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "type": "python", "request": "attach", "connect": { - "host": "SH-IDCA1404-10-140-54-73", + "host": "SH-IDCA1404-10-140-54-50", "port": 5678 }, "pathMappings": [ diff --git a/smoe/callbacks/tensorboard.py b/smoe/callbacks/tensorboard.py index ca9aa9f..1bb0585 100644 --- a/smoe/callbacks/tensorboard.py +++ b/smoe/callbacks/tensorboard.py @@ -93,4 +93,10 @@ def on_log( self.tb_writer.add_image( k, get_heatmap_img_grid_for_tb(v), state.global_step ) + elif k == "train/consumed_tokens": + v.update({"total_tokens": sum(v.values())}) + for name, val in v.items(): + self.tb_writer.add_scalar( + f"consumed_tokens/{name}", val, state.global_step + ) self.tb_writer.flush() diff --git a/smoe/data/streaming.py b/smoe/data/streaming.py index cd9b1c9..689779e 100644 --- a/smoe/data/streaming.py +++ b/smoe/data/streaming.py @@ -6,9 +6,11 @@ """ import random +from collections import defaultdict from pathlib import Path -from typing import Iterator +from typing import Any, Callable, Iterable, Iterator +import numpy as np import torch from torch.utils.data import Dataset, IterableDataset @@ -16,7 +18,7 @@ from smoe.utils.io import load_jsonlines, load_jsonlines_iter from smoe.utils.logging import get_logger from smoe.utils.random_utils import get_random_string -from smoe.utils.vars import JSONL_DATASET_CACHE_NAME +from smoe.utils.vars import JSONL_DATASET_CACHE_NAME, META_SUFFIX logger = get_logger(__file__) @@ -217,6 +219,29 @@ def __len__(self): return len(self.cached) +def batchify_loader(dataset: Iterable, batch_size: int, collate_fn: Callable): + batch = [] + for ins in dataset: + batch.append(ins) + if len(batch) >= batch_size: + yield collate_fn(batch) + batch.clear() + if len(batch) > 0: + yield collate_fn(batch) + batch.clear() + + +class BufferAggregation: + def __init__(self, block_size: int) -> None: + self.block_size = block_size + + def __call__(self, buffer) -> Any: + results = buffer + if self.block_size > 0 and len(buffer) > 0: + results = group_instances(buffer, self.block_size) + return results + + class PackedJsonlDataset(IterableDataset): def __init__( self, @@ -224,51 +249,82 @@ def __init__( seed: int = 1227, buffer_size: int = 200, block_size: int = 2048, + skip_tokens: int = 0, ) -> None: super().__init__() self.data_dir = data_dir self.rng = random.Random(seed) self.buffer_size = buffer_size self.block_size = block_size + self.skip_tokens = skip_tokens data_dir_path = Path(data_dir) filepaths = sorted(data_dir_path.glob("**/*.jsonl")) self.rng.shuffle(filepaths) self.filepaths = filepaths - self.visited_filepaths = [] + self.curr_filepath_pointer = -1 + self.consumed_tokens: int = 0 + + self.buffer_aggregation = BufferAggregation(self.block_size) + + def next_filepath(self) -> str: + if len(self.filepaths) == 0: + raise RuntimeError(f"There's no filepath in {self.data_dir}") + + self.curr_filepath_pointer += 1 + if self.curr_filepath_pointer >= len(self.filepaths): + self.curr_filepath_pointer = 0 + + num_skipped_filepath = 0 + while self.consumed_tokens < self.skip_tokens: + filepath = self.filepaths[self.curr_filepath_pointer] + # meta: [[current token number in the whole file, length of the current instance], ...] + meta: np.ndarray = np.load(filepath + META_SUFFIX) + curr_filepath_tokens = meta.sum(axis=0)[1] + if self.consumed_tokens + curr_filepath_tokens > self.skip_tokens: + break + self.consumed_tokens += curr_filepath_tokens + self.curr_filepath_pointer += 1 + if self.curr_filepath_pointer >= len(self.filepaths): + self.curr_filepath_pointer = 0 + num_skipped_filepath += 1 + + if num_skipped_filepath > 0: + logger.info( + f"Skip {num_skipped_filepath} files," + f" {self.consumed_tokens} tokens," + f" remaining {self.skip_tokens - self.consumed_tokens} tokens to skip." + ) - self.buffer = [] + return self.filepaths[self.curr_filepath_pointer] def __iter__(self) -> Iterator: - self.buffer = [] - for filepath in self.filepaths: - logger.debug(f"Iter over jsonl file: {filepath}") - for ins in load_jsonlines_iter(filepath): - if self.buffer_size <= 1: + filepath = self.next_filepath() + logger.debug(f"Iter over jsonl file: {filepath}") + ds = load_jsonlines_iter(filepath) + # if self.consumed_tokens < self.skip_tokens: + # remaining_skip_tokens = self.skip_tokens - self.consumed_tokens + # # zhutong: here, the skip method is not perfect since there is batch grouping, + # # and the final token number per instance may be different. + # num_skip_lines = (meta[:, 1].cumsum() > remaining_skip_tokens).nonzero()[0][0] + # ds.skip_lines(num_skip_lines) + # self.consumed_tokens += meta[:num_skip_lines].sum(axis=0)[1] + for batch in batchify_loader(ds, self.buffer_size, self.buffer_aggregation): + for ins in batch: + if self.consumed_tokens >= self.skip_tokens: + self.consumed_tokens += len(ins["input_ids"]) yield ins - continue - - if len(self.buffer) >= self.buffer_size: - if len(self.buffer) > 0: - self.rng.shuffle(self.buffer) - self.buffer_aggregation() - yield from self.buffer - self.buffer.clear() - - self.buffer.append(ins) - self.visited_filepaths.append(filepath) - # for the last batch < buffer_size - if len(self.buffer) > 0: - self.rng.shuffle(self.buffer) - self.buffer_aggregation() - yield from self.buffer - self.buffer.clear() - - def buffer_aggregation(self): - if self.block_size > 0 and len(self.buffer) > 0: - results = group_instances(self.buffer, self.block_size) - self.buffer = results + def state_dict(self): + return { + "data_dir": self.data_dir, + "seed": self.seed, + "rng": self.rng.getstate(), + "buffer_size": self.buffer_size, + "block_size": self.block_size, + "filepaths": self.filepaths, + "consumed_tokens": self.consumed_tokens, + } class SubDirWeightedPackedJsonlDataset(IterableDataset): @@ -304,9 +360,12 @@ def __init__( seed: int = 1227, buffer_size: int = 200, block_size: int = 2048, + skip_tokens: dict = {}, ) -> None: self.rng = random.Random(seed) + self.seed = seed self.buffer_size = buffer_size + self.block_size = block_size self.dataset_dir_path = Path(dataset_dir) task_types = [p.stem for p in self.dataset_dir_path.glob("*") if p.is_dir()] @@ -322,6 +381,7 @@ def __init__( ) self.prob_map = prob_map + self.consumed_tokens = skip_tokens self.task_type_to_dataset = {} for task_type in task_types: # zhutong: use iter to support next() calling, since the dataset itself @@ -332,12 +392,25 @@ def __init__( seed=seed, buffer_size=buffer_size, block_size=block_size, + skip_tokens=skip_tokens.get(task_type, 0), ) ) self.task_type_to_dataset[task_type] = ds - def skip_tokens(self, skip_tokens: int): - raise NotImplementedError + def skip_tokens(self, skip_tokens: dict): + for task_type, num_skip_tokens in skip_tokens.items(): + self.task_type_to_dataset[task_type] = iter( + PackedJsonlDataset( + str(self.dataset_dir_path.joinpath(task_type)), + seed=self.seed, + buffer_size=self.buffer_size, + block_size=self.block_size, + skip_tokens=skip_tokens.get(task_type, 0), + ) + ) + if task_type not in self.consumed_tokens: + self.consumed_tokens[task_type] = 0 + self.consumed_tokens[task_type] += num_skip_tokens def __iter__(self) -> Iterator: while len(self.task_type_to_dataset) > 0: @@ -345,7 +418,11 @@ def __iter__(self) -> Iterator: weights = [self.prob_map[task_type] for task_type in candidate_task_types] choice = self.rng.choices(candidate_task_types, weights=weights, k=1)[0] try: - yield next(self.task_type_to_dataset[choice]) + ins = next(self.task_type_to_dataset[choice]) + if choice not in self.consumed_tokens: + self.consumed_tokens[choice] = 0 + self.consumed_tokens[choice] += len(ins["input_ids"]) + yield ins except StopIteration: # self.task_type_to_dataset.pop(choice) # logger.debug(f"Task type {choice} finished, drop it") diff --git a/smoe/trainer/llama_lr_scheduling.py b/smoe/trainer/llama_lr_scheduling.py index 65b38e3..b134857 100644 --- a/smoe/trainer/llama_lr_scheduling.py +++ b/smoe/trainer/llama_lr_scheduling.py @@ -5,7 +5,7 @@ import socket import sys import time -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Any, Dict, Union @@ -138,7 +138,7 @@ def _get_cosine_schedule_with_warmup_lr_lambda( class EnhancedTrainerState(TrainerState): # last Token/GPU/second timestamp start_timestamp: float = 0.0 - consumed_tokens: int = 0 + consumed_tokens: dict = field(default_factory=dict) class LlamaLrSchedulingTrainer(Trainer): @@ -146,7 +146,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.args: EnhancedTrainingArguments - self.state: EnhancedTrainerState + self.state: EnhancedTrainerState = EnhancedTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) def create_optimizer(self): """ @@ -434,9 +437,13 @@ def get_train_dataloader(self) -> DataLoader: train_dataset = self.train_dataset data_collator = self.data_collator - # if not self.args.ignore_data_skip: - # skip_tokens = self.state.global_step * self.args.num_tokens_per_batch - # train_dataset.skip_tokens(skip_tokens) + # zhutong: update consumed_tokens + state_ctokens = sum(self.state.consumed_tokens.values()) + dataset_ctokens = sum(train_dataset.consumed_tokens.values()) + if state_ctokens > dataset_ctokens: + train_dataset.skip_tokens(state_ctokens) + # bind dataset recordings to state + self.state.consumed_tokens = train_dataset.consumed_tokens if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns( @@ -705,7 +712,7 @@ def _inner_training_loop( logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." - f" Total skip tokens: {self.state.consumed_tokens}" + f" Consumed tokens: {self.state.consumed_tokens}" ) # Update the references @@ -915,9 +922,7 @@ def _inner_training_loop( model.zero_grad() self.state.global_step += 1 - self.state.consumed_tokens = ( - self.state.global_step * self.args.num_tokens_per_batch - ) + self.state.consumed_tokens = self.train_dataset.consumed_tokens self.state.epoch = ( epoch + (step + 1 + steps_skipped) / steps_in_epoch ) diff --git a/smoe/utils/io.py b/smoe/utils/io.py index 69bf156..a7c96b0 100644 --- a/smoe/utils/io.py +++ b/smoe/utils/io.py @@ -60,6 +60,11 @@ def __init__(self, filepath, start_from: int = None) -> None: if start_from: self.fin.seek(start_from, os.SEEK_SET) + def skip_lines(self, num_skip_lines: int): + for i, _ in enumerate(self.fin, 1): + if i == num_skip_lines: + break + def tell(self): return self.fin.tell() diff --git a/smoe/utils/tokenize.py b/smoe/utils/tokenize.py index f9d4529..747ba8e 100644 --- a/smoe/utils/tokenize.py +++ b/smoe/utils/tokenize.py @@ -8,6 +8,8 @@ from tqdm import tqdm from transformers import AutoTokenizer +from smoe.utils.vars import META_SUFFIX + def get_parser(): parser = argparse.ArgumentParser() @@ -70,10 +72,10 @@ def prepare_meta(jsonl_filepath: str): ins = json.loads(line) length = len(ins["input_ids"]) meta.append((cur, length)) - cur += len(line) + cur += length # define path of the generated meta file - meta_fp = jsonl_filepath + ".meta" + meta_fp = jsonl_filepath + META_SUFFIX # save the generated meta information with open(meta_fp, "wb") as f: meta = np.array(meta, dtype=np.int32) @@ -162,6 +164,8 @@ def update_meta_without_tokenization(data_dir: str): if __name__ == "__main__": tokenize_jsonl() + + # # uncomment and run: srun -p MoE -c 16 python -m smoe.utils.tokenize # update_meta_without_tokenization( # "/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed" # ) diff --git a/smoe/utils/vars.py b/smoe/utils/vars.py index b68c481..304b047 100644 --- a/smoe/utils/vars.py +++ b/smoe/utils/vars.py @@ -3,3 +3,4 @@ MIDDLE_MODEL_CKPT_DIR = "middle" CLUSTERING_MODEL_NAME = "clustering.model" JSONL_DATASET_CACHE_NAME = "jsonl_dataset-{}.bin" +META_SUFFIX = ".meta" diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 72b897d..26722f6 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -111,29 +111,37 @@ def test_weighted_streaming_loader(): "en_arxiv": 0.025, "en_stack": 0.02, } + num_test_case = 2000 + block_size = 2048 + bsz = 1 + lm_datasets = SubDirWeightedPackedJsonlDataset( - "/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed", + "/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed", prob_map=prob_map, seed=1227, - block_size=2048, + block_size=block_size, ) - num_test_case = 2000 - bsz = 8 loader = DataLoader( lm_datasets, batch_size=bsz, - num_workers=4, + num_workers=0, collate_fn=fault_tolerance_data_collator, pin_memory=True, ) - for batch in loader: + for batch_idx, batch in enumerate(loader): if num_test_case <= 0: break assert len(batch["input_ids"]) == bsz + assert sum(loader.dataset.consumed_tokens.values()) == bsz * block_size num_test_case -= 1 +def test_skip_tokens(): + pass + + if __name__ == "__main__": # test_jsonl_dataset() # test_subdir_weighted_pack_with_type() - test_weighted_streaming() + # test_weighted_streaming() + test_weighted_streaming_loader() From b607b0b91ac35157e806aa91588f04887220bc24 Mon Sep 17 00:00:00 2001 From: zhutong Date: Wed, 15 Nov 2023 21:55:22 +0800 Subject: [PATCH 2/4] update dynamic data selection strategies --- .gitignore | 1 + .vscode/launch.json | 2 +- .../cpt/dynamic_data_selection/baseline.sh | 164 ++ .../dynamic_data_selection/baseline_32gpus.sh | 164 ++ .../sheared_llama_112gpus.sh | 165 ++ .../sheared_llama_paper.sh | 165 ++ scripts/eval/ref_loss.sh | 97 + scripts/eval/ref_loss_random_split.sh | 96 + smoe/data/dynamic_selection.py | 130 + smoe/data/streaming.py | 24 +- .../analysis/scale_factor_simulation.py | 64 + smoe/entrypoint/cpt/cpt_fpt.py | 26 +- smoe/modules/moe/moe_calculators.py | 1 + smoe/trainer/llama_lr_scheduling.py | 54 +- smoe/utils/config.py | 6 + smoe/utils/convert_moe_to_dense.py | 2318 +++++++++++++++++ smoe/utils/debugging.py | 22 +- smoe/utils/logging.py | 2 +- smoe/utils/param_estimation.py | 34 +- tests/data/test_streaming.py | 26 +- 20 files changed, 3527 insertions(+), 34 deletions(-) create mode 100644 scripts/cpt/dynamic_data_selection/baseline.sh create mode 100644 scripts/cpt/dynamic_data_selection/baseline_32gpus.sh create mode 100644 scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh create mode 100644 scripts/cpt/dynamic_data_selection/sheared_llama_paper.sh create mode 100644 scripts/eval/ref_loss.sh create mode 100644 scripts/eval/ref_loss_random_split.sh create mode 100644 smoe/data/dynamic_selection.py create mode 100644 smoe/entrypoint/analysis/scale_factor_simulation.py create mode 100644 smoe/utils/convert_moe_to_dense.py diff --git a/.gitignore b/.gitignore index 88a955c..db46c66 100644 --- a/.gitignore +++ b/.gitignore @@ -177,6 +177,7 @@ results/random_16select4_moe/ results/llama2_7B_gradient_share_gate_load/ results/gate_loss.png results/scale_distribution/ +results/analysis_scale_factor/ smoe/utils/gpu_diag.py /visualization_change_13B/ /visualization_change_7B/ diff --git a/.vscode/launch.json b/.vscode/launch.json index c70d9a0..38c3d90 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "type": "python", "request": "attach", "connect": { - "host": "SH-IDCA1404-10-140-54-50", + "host": "SH-IDCA1404-10-140-54-123", "port": 5678 }, "pathMappings": [ diff --git a/scripts/cpt/dynamic_data_selection/baseline.sh b/scripts/cpt/dynamic_data_selection/baseline.sh new file mode 100644 index 0000000..c075bbd --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline.sh @@ -0,0 +1,164 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_16gpus +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_16gpus/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_16gpus/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=2 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=2 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, per-device bsz 4M tokens, lr 1e-4, 16gpus" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=1e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=8 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="1*10^9" + # warmup_tokens="0" + eval_tokens="1*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / $block_size" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / $tokens_per_batch" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / $tokens_per_batch" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_16gpus" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh b/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh new file mode 100644 index 0000000..9d144ac --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh @@ -0,0 +1,164 @@ +#!/usr/bin/bash + +#SBATCH --job-name=llama2_random_scale4_32gpus +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=4 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=4 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, per-device bsz 4M tokens, lr 1e-4, 16gpus" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=1e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=8 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="1*10^9" + # warmup_tokens="0" + eval_tokens="1*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / $block_size" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / $tokens_per_batch" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / $tokens_per_batch" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh new file mode 100644 index 0000000..b364ee6 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh @@ -0,0 +1,165 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, per-device bsz 4M tokens, lr 1e-4, 16gpua" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=1e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=8 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="1*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / $block_size" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / $tokens_per_batch" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / $tokens_per_batch" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --dynamic_data_selection "sheared_llama" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_paper.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_paper.sh new file mode 100644 index 0000000..a976e2b --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_paper.sh @@ -0,0 +1,165 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_16gpus +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_16gpus/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_16gpus/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=2 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=2 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, per-device bsz 4M tokens, lr 1e-4, 16gpua" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=1e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=8 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="1*10^9" + # warmup_tokens="0" + eval_tokens="500*10^6" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / $block_size" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / $tokens_per_batch" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / $tokens_per_batch" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_16gpus" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --dynamic_data_selection "sheared_llama_paper" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/eval/ref_loss.sh b/scripts/eval/ref_loss.sh new file mode 100644 index 0000000..9cb1d06 --- /dev/null +++ b/scripts/eval/ref_loss.sh @@ -0,0 +1,97 @@ +#!/usr/bin/bash + +#SBATCH --job-name=eval_ref_loss +#SBATCH --output=logs/%x-%j.log +#SBATCH --error=logs/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=0 + +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=1 # should match with --nodes + num_gpu_per_node=1 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + model_type="llama" + comment="llama 2 7B evaluation" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + pretrained_model=/mnt/petrelfs/zhutong/smoe/outputs/random_split_scale4_112gpus_11900steps_dense + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + ############################################################## + + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + block_size=4096 + seed=1227 + + data_cache=resources/cache + base_dir="." + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + # echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + # ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + + python smoe/entrypoint/cpt/cpt_fpt.py \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_eval \ + --seed ${seed} \ + --bf16 \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none +} diff --git a/scripts/eval/ref_loss_random_split.sh b/scripts/eval/ref_loss_random_split.sh new file mode 100644 index 0000000..551c8d2 --- /dev/null +++ b/scripts/eval/ref_loss_random_split.sh @@ -0,0 +1,96 @@ +#!/usr/bin/bash + +#SBATCH --job-name=eval_ref_loss +#SBATCH --output=logs/%x-%j.log +#SBATCH --error=logs/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=0 + +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=1 # should match with --nodes + num_gpu_per_node=1 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + model_type="llama_moe" + comment="llama 2 7B evaluation" + pretrained_model=/mnt/petrelfs/zhutong/smoe/outputs/random_split_scale4_112gpus_11900steps + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + ############################################################## + + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + block_size=4096 + seed=1227 + + data_cache=resources/cache + base_dir="." + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + + python smoe/entrypoint/cpt/cpt_fpt.py \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_eval \ + --seed ${seed} \ + --bf16 \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none +} diff --git a/smoe/data/dynamic_selection.py b/smoe/data/dynamic_selection.py new file mode 100644 index 0000000..f9049c2 --- /dev/null +++ b/smoe/data/dynamic_selection.py @@ -0,0 +1,130 @@ +import numpy as np +import torch + +LLAMA_DATA_PORTION = { + "en_cc": 0.67, + "en_c4": 0.15, + "github": 0.045, + "en_wikipedia": 0.045, + "en_book": 0.045, + "en_arxiv": 0.025, + "en_stack": 0.02, +} + +LLAMA_DATA_PORTION_AVG = { + "en_cc": 1 / 7, + "en_c4": 1 / 7, + "github": 1 / 7, + "en_wikipedia": 1 / 7, + "en_book": 1 / 7, + "en_arxiv": 1 / 7, + "en_stack": 1 / 7, +} + +LLAMA2_7B_SLIMPAJAMA_VAL_REF_LOSS = { + "en_book": 1.925248146057129, + "en_wikipedia": 1.5899001359939575, + "en_stack": 1.4974864721298218, + "github": 0.6984495520591736, + "en_c4": 2.074881076812744, + "en_cc": 1.6916865110397339, + "en_arxiv": 1.2408167123794556, +} + + +""" +llama2-7B + +hellaswag: 2.664067268371582 +mmlu: 2.3618555068969727 +arc_challenge: 3.6212270259857178 +gsm8k: 1.280044436454773 +""" + + +def update_weight_sheared_llama_paper( + prob_map: dict[str, float], ref_loss: dict[str, float], curr_loss: dict[str, float] +) -> dict[str, float]: + """ + Args: + prob_map: dataset name -> prob + ref_loss: dataset name -> ref loss + curr_loss: dataset name -> curr loss + + Returns: + prob_map: updated prob map + + References: + Dynamic Batch Loading in ShearedLlama (http://arxiv.org/abs/2310.06694) + """ + task_types = [k for k in prob_map] + original_weight = np.array([prob_map[k] for k in task_types]) + loss_delta = np.array([max(0, curr_loss[k] - ref_loss[k]) for k in task_types]) + + # original method + alpha = original_weight * np.exp(loss_delta) + alpha /= alpha.sum() + + # method 2 + # ref_loss_arr = np.array([ref_loss[k] for k in task_types]) + # alpha = original_weight * np.exp(loss_delta / ref_loss_arr) + # alpha /= alpha.sum() + + # method 3 + # curr_loss_arr = np.array([curr_loss[k] for k in task_types]) + # ref_loss_arr = np.array([ref_loss[k] for k in task_types]) + # loss_delta_arr = curr_loss_arr - ref_loss_arr + # # loss_delta_arr /= ref_loss_arr + # lr = 1.0 + # alpha = original_weight + lr * loss_delta_arr + + return {k: v for k, v in zip(task_types, alpha)} + + +def update_weight_sheared_llama( + prob_map: dict[str, float], ref_loss: dict[str, float], curr_loss: dict[str, float] +) -> dict[str, float]: + """ + Args: + prob_map: dataset name -> prob + ref_loss: dataset name -> ref loss + curr_loss: dataset name -> curr loss + + Returns: + prob_map: updated prob map + + References: + Dynamic Batch Loading in ShearedLlama (http://arxiv.org/abs/2310.06694) + """ + task_types = [k for k in prob_map] + original_weight = torch.tensor([prob_map[k] for k in task_types]) + diff = torch.tensor([curr_loss[k] - ref_loss[k] for k in task_types]) + eta = 1.0 + c = 1e-4 + + updated_alpha = torch.log(original_weight) + eta * diff + updated_alpha = torch.nn.functional.softmax(updated_alpha, dim=0) + updated_domain_weights = (1 - c) * updated_alpha + c / len(task_types) + updated_domain_weights = updated_domain_weights.detach().numpy().astype("float64") + updated_domain_weights = updated_domain_weights / updated_domain_weights.sum() + updated_domain_weights = updated_domain_weights.tolist() + + return {k: v for k, v in zip(task_types, updated_domain_weights)} + + +if __name__ == "__main__": + # new_weight = update_weight_sheared_llama_paper( + new_weight = update_weight_sheared_llama( + LLAMA_DATA_PORTION, + LLAMA2_7B_SLIMPAJAMA_VAL_REF_LOSS, + { + "en_book": 2.071, + "en_wikipedia": 1.572, + "en_stack": 1.491, + "github": 0.705, + "en_c4": 2.117, + "en_cc": 1.728, + "en_arxiv": 1.287, + }, + ) + print(new_weight) diff --git a/smoe/data/streaming.py b/smoe/data/streaming.py index 689779e..5020ae0 100644 --- a/smoe/data/streaming.py +++ b/smoe/data/streaming.py @@ -12,7 +12,7 @@ import numpy as np import torch -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import Dataset, IterableDataset, get_worker_info from smoe.data.aggregation import group_instances from smoe.utils.io import load_jsonlines, load_jsonlines_iter @@ -356,7 +356,7 @@ class SubDirWeightedPackedJsonlDataset(IterableDataset): def __init__( self, dataset_dir: str, - prob_map: dict[str, float] = None, + prob_map: dict[str, float] | list[tuple[str, int]] = None, seed: int = 1227, buffer_size: int = 200, block_size: int = 2048, @@ -379,7 +379,17 @@ def __init__( logger.warning( f"Task type {task_type} not found in dataset dir. Skip it." ) - self.prob_map = prob_map + self.source2idx = {} + self.prob_map = {} + if isinstance(prob_map, dict): + _prob_map = list(prob_map.items()) + elif isinstance(prob_map, list): + _prob_map = prob_map + else: + raise ValueError(f"Unknown prob_map type: {type(prob_map)}") + for task_type, sampling_weight in _prob_map: + self.source2idx[task_type] = len(self.source2idx) + self.prob_map[task_type] = sampling_weight self.consumed_tokens = skip_tokens self.task_type_to_dataset = {} @@ -412,6 +422,14 @@ def skip_tokens(self, skip_tokens: dict): self.consumed_tokens[task_type] = 0 self.consumed_tokens[task_type] += num_skip_tokens + def update_prob_map(self, new_prob_map: dict): + self.prob_map.update(new_prob_map) + + def update_existed_prob_map(self, new_prob_map: dict): + for name in self.prob_map: + if name in new_prob_map: + self.prob_map[name] = new_prob_map[name] + def __iter__(self) -> Iterator: while len(self.task_type_to_dataset) > 0: candidate_task_types = list(self.task_type_to_dataset.keys()) diff --git a/smoe/entrypoint/analysis/scale_factor_simulation.py b/smoe/entrypoint/analysis/scale_factor_simulation.py new file mode 100644 index 0000000..c5e5512 --- /dev/null +++ b/smoe/entrypoint/analysis/scale_factor_simulation.py @@ -0,0 +1,64 @@ +import statistics as sts +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +xs = [] +ys = [] + + +def kaiming_init(*size): + tensor = torch.randn(*size) + nn.init.kaiming_normal_(tensor, mode="fan_out") + return tensor + + +init_func = torch.randn +# init_func = kaiming_init + +max_num_experts = 32 +intermediate_size = 11008 +hidden_size = 4096 +base = None +for k in range(1, max_num_experts + 1): + mid = int(intermediate_size / k) + distances = [] + for _ in range(10): + gate = init_func(hidden_size, mid) + up = init_func(hidden_size, mid) + down = init_func(mid, hidden_size) + + x = init_func(1, hidden_size) + # y = x @ l1 @ l2 + y = (F.silu(x @ gate) * (x @ up)) @ down + + dist = (x - y).abs().sum() + # dist = (x - y).pow(2).sum() + distances.append(dist.item()) + + xs.append(k) + if base is None and k == 1: + base = sts.mean(distances) + ys.append(base / sts.mean(distances)) + print(xs[-1], ys[-1]) + + +plt.plot(xs, ys, label="simulation") +plt.plot(xs, np.sqrt(xs), label="sqrt", linestyle="dashed") +plt.legend() +plt.xlabel("#Experts") +plt.ylabel("Scale Factor") +plt.grid(True, zorder=-1) + +# plt.title("SwiGLU Kaiming Normal Initialization (fan_out)") +# plt.savefig("swiglu_kaiming_fan_out_1024.png") + +out_dir = Path("results/analysis_scale_factor") +out_dir.mkdir(exist_ok=True, parents=True) +plt.title("Normal Initialization") +plt.savefig(out_dir / "normal.png") +# plt.savefig(out_dir / "normal_dropout_rescale.png") diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py index ef594fa..1ad0e6d 100644 --- a/smoe/entrypoint/cpt/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -60,7 +60,7 @@ logger = logging.getLogger(__name__) -@wechat_sender() +# @wechat_sender() def main(): model_args, data_args, training_args = parse_args( ModelArguments, DataArguments, EnhancedTrainingArguments @@ -341,11 +341,12 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - # update config for checkpoint retrival - # model.set_moe_gate_balance_loss_weight(0.1) - model.set_moe_calculator_score_scale_factor(4.0) - # model.set_moe_calculator_score_scale_factor(1.0) - model.update_config() + if hasattr(model, "set_moe_calculator_score_scale_factor"): + # update config for checkpoint retrival + # model.set_moe_gate_balance_loss_weight(0.1) + model.set_moe_calculator_score_scale_factor(4.0) + # model.set_moe_calculator_score_scale_factor(1.0) + model.update_config() model_vocab_size = model.get_output_embeddings().weight.size(0) if model_vocab_size != len(tokenizer): @@ -393,7 +394,18 @@ def make_inputs_require_grad(module, input, output): # Evaluation if training_args.do_eval: - raise NotImplementedError + if isinstance(trainer.eval_dataset, dict): + metrics = {} + for eval_dataset_name, eval_dataset in trainer.eval_dataset.items(): + dataset_metrics = trainer.evaluate( + eval_dataset=eval_dataset, + ignore_keys=None, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + else: + metrics = trainer.evaluate(ignore_keys=None) + logger.info(f"{metrics}") if __name__ == "__main__": diff --git a/smoe/modules/moe/moe_calculators.py b/smoe/modules/moe/moe_calculators.py index d7a7499..ec6e998 100644 --- a/smoe/modules/moe/moe_calculators.py +++ b/smoe/modules/moe/moe_calculators.py @@ -127,6 +127,7 @@ def forward( if self.multiply_gate_scores: if self.mlp_norm is None: cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) * self.score_scale_factor) # 乘权重 + # cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) * 1.0) # 乘权重 else: cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1)) # 乘权重 cat_expert_outputs = self.mlp_norm(cat_expert_outputs) diff --git a/smoe/trainer/llama_lr_scheduling.py b/smoe/trainer/llama_lr_scheduling.py index b134857..95c3a76 100644 --- a/smoe/trainer/llama_lr_scheduling.py +++ b/smoe/trainer/llama_lr_scheduling.py @@ -51,6 +51,11 @@ logging, ) +from smoe.data.dynamic_selection import ( + LLAMA2_7B_SLIMPAJAMA_VAL_REF_LOSS, + update_weight_sheared_llama, + update_weight_sheared_llama_paper, +) from smoe.utils.config import EnhancedTrainingArguments if is_apex_available(): @@ -139,6 +144,7 @@ class EnhancedTrainerState(TrainerState): # last Token/GPU/second timestamp start_timestamp: float = 0.0 consumed_tokens: dict = field(default_factory=dict) + tot_consumed_tokens: int = 0 class LlamaLrSchedulingTrainer(Trainer): @@ -336,7 +342,7 @@ def _maybe_log_save_evaluate( x.detach().cpu().tolist() for x in gate_importance ] logs["balance_loss"] = balance_loss.item() - logs["consumed_tokens"] = self.state.consumed_tokens + logs["consumed_tokens"] = self.state.tot_consumed_tokens self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step @@ -368,6 +374,33 @@ def _maybe_log_save_evaluate( metric_to_check = f"eval_{metric_to_check}" self.lr_scheduler.step(metrics[metric_to_check]) + if ( + isinstance(self.args.dynamic_data_selection, str) + and self.args.dynamic_data_selection != "none" + and metrics is not None + and self.train_dataset is not None + ): + curr_loss_map = {} + for key, value in metrics.items(): + sobj = re.search(r"eval_(.*?)_loss", key) + if sobj: + dataset_name = sobj.group(1) + curr_loss_map[dataset_name] = value + new_prob_map = self.train_dataset.prob_map + if self.args.dynamic_data_selection == "sheared_llama_paper": + new_prob_map = update_weight_sheared_llama_paper( + self.train_dataset.prob_map, + LLAMA2_7B_SLIMPAJAMA_VAL_REF_LOSS, + curr_loss_map, + ) + elif self.args.dynamic_data_selection == "sheared_llama": + new_prob_map = update_weight_sheared_llama( + self.train_dataset.prob_map, + LLAMA2_7B_SLIMPAJAMA_VAL_REF_LOSS, + curr_loss_map, + ) + self.train_dataset.update_existed_prob_map(new_prob_map) + if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) # zhutong: lr_scheduler is passed as an arg in `transformers.trainer_callback/CallbackHandler.call_event()` @@ -437,13 +470,13 @@ def get_train_dataloader(self) -> DataLoader: train_dataset = self.train_dataset data_collator = self.data_collator - # zhutong: update consumed_tokens - state_ctokens = sum(self.state.consumed_tokens.values()) - dataset_ctokens = sum(train_dataset.consumed_tokens.values()) - if state_ctokens > dataset_ctokens: - train_dataset.skip_tokens(state_ctokens) - # bind dataset recordings to state - self.state.consumed_tokens = train_dataset.consumed_tokens + # # zhutong: update consumed_tokens + # state_ctokens = sum(self.state.consumed_tokens.values()) + # dataset_ctokens = sum(train_dataset.consumed_tokens.values()) + # if state_ctokens > dataset_ctokens: + # train_dataset.skip_tokens(state_ctokens) + # # bind dataset recordings to state + # self.state.consumed_tokens = train_dataset.consumed_tokens if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns( @@ -712,7 +745,7 @@ def _inner_training_loop( logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." - f" Consumed tokens: {self.state.consumed_tokens}" + f" Consumed tokens: {self.state.tot_consumed_tokens}" ) # Update the references @@ -922,7 +955,8 @@ def _inner_training_loop( model.zero_grad() self.state.global_step += 1 - self.state.consumed_tokens = self.train_dataset.consumed_tokens + # self.state.consumed_tokens = self.train_dataset.consumed_tokens + self.state.tot_consumed_tokens += self.args.num_tokens_per_batch self.state.epoch = ( epoch + (step + 1 + steps_skipped) / steps_in_epoch ) diff --git a/smoe/utils/config.py b/smoe/utils/config.py index 76642ed..525c89e 100644 --- a/smoe/utils/config.py +++ b/smoe/utils/config.py @@ -313,6 +313,12 @@ class EnhancedTrainingArguments(TrainingArguments): "help": "The number of model parameters used for training. If set to -1, it will be calculated automatically." }, ) + dynamic_data_selection: Optional[str] = field( + default="none", + metadata={ + "help": "dynamic data selection strategy (change data portion dynamically based on current loss and reference loss)." + }, + ) @property def block_size(self): diff --git a/smoe/utils/convert_moe_to_dense.py b/smoe/utils/convert_moe_to_dense.py new file mode 100644 index 0000000..4b618b3 --- /dev/null +++ b/smoe/utils/convert_moe_to_dense.py @@ -0,0 +1,2318 @@ +""" +Convert MoEfication models back to dense models. +""" +import json +import re +import shutil +from pathlib import Path + +import torch +from transformers.modeling_utils import dtype_byte_size +from transformers.utils import WEIGHTS_INDEX_NAME + +# from transformers.models.llama import LlamaForCausalLM + +# from smoe.models.llama_moe import LlamaMoEForCausalLM + + +def get_layer_nums(keys): + layers = [] + for key in keys: + s = re.search(r"model\.layers\.(\d+)\.", key) + if s: + layers.append(int(s.group(1))) + layers.sort() + return layers + + +def get_num_experts(keys): + num_experts = 0 + for key in keys: + s = re.search( + r"model\.layers\.\d+\.mlp\.calculator\.experts\.weight_up\.(\d+)", key + ) + if s: + num_experts = max(num_experts, int(s.group(1))) + num_experts += 1 + return num_experts + + +def main(): + dense_model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B/" + moe_model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/random_split_scale4_112gpus_11900steps/checkpoint-11900/" + out_dense_model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/random_split_scale4_112gpus_11900steps_dense" + filenames = ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"] + + out_dir = Path(out_dense_model_dir) + out_dir.mkdir(exist_ok=True, parents=True) + total_size = 0 + weight_map = {} + for name in filenames: + dense = {} + moe = torch.load(Path(moe_model_dir) / name, map_location="cpu") + for key in ["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"]: + if key in moe: + dense[key] = moe[key] + layers = get_layer_nums(moe.keys()) + num_experts = get_num_experts(moe.keys()) + for layer_idx in layers: + for key in ["q_proj", "k_proj", "v_proj", "o_proj"]: + dense[f"model.layers.{layer_idx}.self_attn.{key}.weight"] = moe[ + f"model.layers.{layer_idx}.self_attn.{key}.weight" + ] + dense[f"model.layers.{layer_idx}.self_attn.rotary_emb.inv_freq"] = moe[ + f"model.layers.{layer_idx}.self_attn.rotary_emb.inv_freq" + ] + dense[f"model.layers.{layer_idx}.input_layernorm.weight"] = moe[ + f"model.layers.{layer_idx}.input_layernorm.weight" + ] + dense[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = moe[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ] + + up_proj = [] + for expert_idx in range(num_experts): + param = moe[ + f"model.layers.{layer_idx}.mlp.calculator.experts.weight_up.{expert_idx}" + ] + up_proj.append(param) + up_proj_cat = torch.cat(up_proj, dim=0) + dense[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_proj_cat + + gate_proj = [] + for expert_idx in range(num_experts): + param = moe[ + f"model.layers.{layer_idx}.mlp.calculator.experts.weight_gate.{expert_idx}" + ] + gate_proj.append(param) + gate_proj_cat = torch.cat(gate_proj, dim=0) + dense[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_proj_cat + + down_proj = [] + for expert_idx in range(num_experts): + param = moe[ + f"model.layers.{layer_idx}.mlp.calculator.experts.weight_down.{expert_idx}" + ] + down_proj.append(param) + down_proj_cat = torch.cat(down_proj, dim=1) + dense[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = down_proj_cat + + torch.save(dense, out_dir / name) + + for p_key, param in dense.items(): + total_size += param.numel() * dtype_byte_size(param.dtype) + weight_map[p_key] = name + + index_map = {"metadata": {"total_size": total_size}, "weight_map": weight_map} + with out_dir.joinpath(WEIGHTS_INDEX_NAME).open("w", encoding="utf8") as fout: + json.dump(index_map, fout, indent=2, ensure_ascii=False) + + # copy files from original dense model to target model + filenames = [ + "config.json", + "generation_config.json", + "special_tokens_map.json", + "tokenizer_config.json", + "tokenizer.model", + ] + for filename in filenames: + shutil.copy(Path(dense_model_dir) / filename, out_dir / filename) + + +if __name__ == "__main__": + main() + # x = torch.load("outputs/random_split_scale4_112gpus_11900steps_dense/pytorch_model-00001-of-00002.bin") + # for name, param in x.items(): + # print(name, param.shape) + + +""" +dense + +model.embed_tokens.weight torch.Size([32000, 4096]) +model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.0.input_layernorm.weight torch.Size([4096]) +model.layers.0.post_attention_layernorm.weight torch.Size([4096]) +model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.1.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.1.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.1.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.1.input_layernorm.weight torch.Size([4096]) +model.layers.1.post_attention_layernorm.weight torch.Size([4096]) +model.layers.2.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.2.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.2.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.2.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.2.input_layernorm.weight torch.Size([4096]) +model.layers.2.post_attention_layernorm.weight torch.Size([4096]) +model.layers.3.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.3.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.3.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.3.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.3.input_layernorm.weight torch.Size([4096]) +model.layers.3.post_attention_layernorm.weight torch.Size([4096]) +model.layers.4.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.4.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.4.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.4.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.4.input_layernorm.weight torch.Size([4096]) +model.layers.4.post_attention_layernorm.weight torch.Size([4096]) +model.layers.5.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.5.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.5.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.5.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.5.input_layernorm.weight torch.Size([4096]) +model.layers.5.post_attention_layernorm.weight torch.Size([4096]) +model.layers.6.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.6.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.6.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.6.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.6.input_layernorm.weight torch.Size([4096]) +model.layers.6.post_attention_layernorm.weight torch.Size([4096]) +model.layers.7.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.7.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.7.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.7.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.7.input_layernorm.weight torch.Size([4096]) +model.layers.7.post_attention_layernorm.weight torch.Size([4096]) +model.layers.8.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.8.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.8.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.8.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.8.input_layernorm.weight torch.Size([4096]) +model.layers.8.post_attention_layernorm.weight torch.Size([4096]) +model.layers.9.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.9.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.9.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.9.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.9.input_layernorm.weight torch.Size([4096]) +model.layers.9.post_attention_layernorm.weight torch.Size([4096]) +model.layers.10.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.10.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.10.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.10.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.10.input_layernorm.weight torch.Size([4096]) +model.layers.10.post_attention_layernorm.weight torch.Size([4096]) +model.layers.11.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.11.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.11.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.11.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.11.input_layernorm.weight torch.Size([4096]) +model.layers.11.post_attention_layernorm.weight torch.Size([4096]) +model.layers.12.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.12.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.12.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.12.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.12.input_layernorm.weight torch.Size([4096]) +model.layers.12.post_attention_layernorm.weight torch.Size([4096]) +model.layers.13.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.13.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.13.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.13.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.13.input_layernorm.weight torch.Size([4096]) +model.layers.13.post_attention_layernorm.weight torch.Size([4096]) +model.layers.14.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.14.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.14.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.14.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.14.input_layernorm.weight torch.Size([4096]) +model.layers.14.post_attention_layernorm.weight torch.Size([4096]) +model.layers.15.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.15.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.15.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.15.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.15.input_layernorm.weight torch.Size([4096]) +model.layers.15.post_attention_layernorm.weight torch.Size([4096]) +model.layers.16.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.16.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.16.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.16.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.16.input_layernorm.weight torch.Size([4096]) +model.layers.16.post_attention_layernorm.weight torch.Size([4096]) +model.layers.17.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.17.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.17.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.17.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.17.input_layernorm.weight torch.Size([4096]) +model.layers.17.post_attention_layernorm.weight torch.Size([4096]) +model.layers.18.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.18.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.18.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.18.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.18.input_layernorm.weight torch.Size([4096]) +model.layers.18.post_attention_layernorm.weight torch.Size([4096]) +model.layers.19.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.19.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.19.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.19.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.19.input_layernorm.weight torch.Size([4096]) +model.layers.19.post_attention_layernorm.weight torch.Size([4096]) +model.layers.20.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.20.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.20.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.20.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.20.input_layernorm.weight torch.Size([4096]) +model.layers.20.post_attention_layernorm.weight torch.Size([4096]) +model.layers.21.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.21.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.21.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.21.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.21.input_layernorm.weight torch.Size([4096]) +model.layers.21.post_attention_layernorm.weight torch.Size([4096]) +model.layers.22.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.22.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.22.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.22.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.22.input_layernorm.weight torch.Size([4096]) +model.layers.22.post_attention_layernorm.weight torch.Size([4096]) +model.layers.23.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.23.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.23.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.23.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.23.input_layernorm.weight torch.Size([4096]) +model.layers.23.post_attention_layernorm.weight torch.Size([4096]) +model.layers.24.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.24.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.24.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.24.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.24.input_layernorm.weight torch.Size([4096]) +model.layers.24.post_attention_layernorm.weight torch.Size([4096]) +model.layers.25.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.25.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.25.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.25.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.25.input_layernorm.weight torch.Size([4096]) +model.layers.25.post_attention_layernorm.weight torch.Size([4096]) +model.layers.26.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.26.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.26.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.26.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.26.input_layernorm.weight torch.Size([4096]) +model.layers.26.post_attention_layernorm.weight torch.Size([4096]) +model.layers.27.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.27.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.27.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.27.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.27.input_layernorm.weight torch.Size([4096]) +model.layers.27.post_attention_layernorm.weight torch.Size([4096]) +model.layers.28.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.28.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.28.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.28.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.28.input_layernorm.weight torch.Size([4096]) +model.layers.28.post_attention_layernorm.weight torch.Size([4096]) +model.layers.29.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.29.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.29.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.29.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.29.input_layernorm.weight torch.Size([4096]) +model.layers.29.post_attention_layernorm.weight torch.Size([4096]) +model.layers.30.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.30.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.30.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.30.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.30.input_layernorm.weight torch.Size([4096]) +model.layers.30.post_attention_layernorm.weight torch.Size([4096]) +model.layers.31.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.31.mlp.gate_proj.weight torch.Size([11008, 4096]) +model.layers.31.mlp.up_proj.weight torch.Size([11008, 4096]) +model.layers.31.mlp.down_proj.weight torch.Size([4096, 11008]) +model.layers.31.input_layernorm.weight torch.Size([4096]) +model.layers.31.post_attention_layernorm.weight torch.Size([4096]) +model.norm.weight torch.Size([4096]) +lm_head.weight torch.Size([32000, 4096]) + +=================================================================== +moe + +model.embed_tokens.weight torch.Size([32000, 4096]) +model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.0.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.0.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.0.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.0.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.0.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.0.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.0.input_layernorm.weight torch.Size([4096]) +model.layers.0.post_attention_layernorm.weight torch.Size([4096]) +model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.1.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.1.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.1.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.1.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.1.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.1.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.1.input_layernorm.weight torch.Size([4096]) +model.layers.1.post_attention_layernorm.weight torch.Size([4096]) +model.layers.2.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.2.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.2.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.2.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.2.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.2.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.2.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.2.input_layernorm.weight torch.Size([4096]) +model.layers.2.post_attention_layernorm.weight torch.Size([4096]) +model.layers.3.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.3.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.3.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.3.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.3.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.3.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.3.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.3.input_layernorm.weight torch.Size([4096]) +model.layers.3.post_attention_layernorm.weight torch.Size([4096]) +model.layers.4.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.4.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.4.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.4.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.4.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.4.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.4.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.4.input_layernorm.weight torch.Size([4096]) +model.layers.4.post_attention_layernorm.weight torch.Size([4096]) +model.layers.5.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.5.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.5.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.5.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.5.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.5.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.5.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.5.input_layernorm.weight torch.Size([4096]) +model.layers.5.post_attention_layernorm.weight torch.Size([4096]) +model.layers.6.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.6.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.6.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.6.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.6.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.6.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.6.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.6.input_layernorm.weight torch.Size([4096]) +model.layers.6.post_attention_layernorm.weight torch.Size([4096]) +model.layers.7.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.7.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.7.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.7.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.7.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.7.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.7.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.7.input_layernorm.weight torch.Size([4096]) +model.layers.7.post_attention_layernorm.weight torch.Size([4096]) +model.layers.8.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.8.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.8.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.8.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.8.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.8.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.8.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.8.input_layernorm.weight torch.Size([4096]) +model.layers.8.post_attention_layernorm.weight torch.Size([4096]) +model.layers.9.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.9.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.9.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.9.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.9.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.9.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.9.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.9.input_layernorm.weight torch.Size([4096]) +model.layers.9.post_attention_layernorm.weight torch.Size([4096]) +model.layers.10.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.10.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.10.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.10.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.10.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.10.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.10.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.10.input_layernorm.weight torch.Size([4096]) +model.layers.10.post_attention_layernorm.weight torch.Size([4096]) +model.layers.11.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.11.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.11.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.11.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.11.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.11.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.11.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.11.input_layernorm.weight torch.Size([4096]) +model.layers.11.post_attention_layernorm.weight torch.Size([4096]) +model.layers.12.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.12.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.12.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.12.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.12.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.12.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.12.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.12.input_layernorm.weight torch.Size([4096]) +model.layers.12.post_attention_layernorm.weight torch.Size([4096]) +model.layers.13.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.13.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.13.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.13.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.13.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.13.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.13.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.13.input_layernorm.weight torch.Size([4096]) +model.layers.13.post_attention_layernorm.weight torch.Size([4096]) +model.layers.14.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.14.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.14.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.14.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.14.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.14.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.14.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.14.input_layernorm.weight torch.Size([4096]) +model.layers.14.post_attention_layernorm.weight torch.Size([4096]) +model.layers.15.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.15.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.15.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.15.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.15.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.15.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.15.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.15.input_layernorm.weight torch.Size([4096]) +model.layers.15.post_attention_layernorm.weight torch.Size([4096]) +model.layers.16.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.16.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.16.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.16.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.16.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.16.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.16.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.16.input_layernorm.weight torch.Size([4096]) +model.layers.16.post_attention_layernorm.weight torch.Size([4096]) +model.layers.17.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.17.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.17.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.17.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.17.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.17.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.17.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.17.input_layernorm.weight torch.Size([4096]) +model.layers.17.post_attention_layernorm.weight torch.Size([4096]) +model.layers.18.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.18.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.18.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.18.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.18.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.18.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.18.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.18.input_layernorm.weight torch.Size([4096]) +model.layers.18.post_attention_layernorm.weight torch.Size([4096]) +model.layers.19.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.19.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.19.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.19.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.19.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.19.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.19.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.19.input_layernorm.weight torch.Size([4096]) +model.layers.19.post_attention_layernorm.weight torch.Size([4096]) +model.layers.20.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.20.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.20.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.20.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.20.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.20.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.20.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.20.input_layernorm.weight torch.Size([4096]) +model.layers.20.post_attention_layernorm.weight torch.Size([4096]) +model.layers.21.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.21.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.21.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.21.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.21.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.21.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.21.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.21.input_layernorm.weight torch.Size([4096]) +model.layers.21.post_attention_layernorm.weight torch.Size([4096]) +model.layers.22.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.22.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.22.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.22.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.22.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.22.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.22.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.22.input_layernorm.weight torch.Size([4096]) +model.layers.22.post_attention_layernorm.weight torch.Size([4096]) +model.layers.23.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.23.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.23.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.23.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.23.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.23.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.23.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.23.input_layernorm.weight torch.Size([4096]) +model.layers.23.post_attention_layernorm.weight torch.Size([4096]) +model.layers.24.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.24.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.24.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.24.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.24.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.24.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.24.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.24.input_layernorm.weight torch.Size([4096]) +model.layers.24.post_attention_layernorm.weight torch.Size([4096]) +model.layers.25.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.25.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.25.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.25.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.25.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.25.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.25.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.25.input_layernorm.weight torch.Size([4096]) +model.layers.25.post_attention_layernorm.weight torch.Size([4096]) +model.layers.26.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.26.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.26.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.26.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.26.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.26.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.26.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.26.input_layernorm.weight torch.Size([4096]) +model.layers.26.post_attention_layernorm.weight torch.Size([4096]) +model.layers.27.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.27.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.27.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.27.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.27.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.27.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.27.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.27.input_layernorm.weight torch.Size([4096]) +model.layers.27.post_attention_layernorm.weight torch.Size([4096]) +model.layers.28.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.28.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.28.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.28.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.28.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.28.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.28.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.28.input_layernorm.weight torch.Size([4096]) +model.layers.28.post_attention_layernorm.weight torch.Size([4096]) +model.layers.29.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.29.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.29.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.29.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.29.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.29.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.29.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.29.input_layernorm.weight torch.Size([4096]) +model.layers.29.post_attention_layernorm.weight torch.Size([4096]) +model.layers.30.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.30.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.30.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.30.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.30.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.30.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.30.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.30.input_layernorm.weight torch.Size([4096]) +model.layers.30.post_attention_layernorm.weight torch.Size([4096]) +model.layers.31.self_attn.q_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.k_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.v_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.o_proj.weight torch.Size([4096, 4096]) +model.layers.31.self_attn.rotary_emb.inv_freq torch.Size([64]) +model.layers.31.mlp.gate.gate_network.0.weight torch.Size([16, 4096]) +model.layers.31.mlp.gate.gate_network.2.weight torch.Size([16, 16]) +model.layers.31.mlp.gate.weight_noise.weight torch.Size([16, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.0 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.1 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.2 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.3 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.4 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.5 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.6 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.7 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.8 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.9 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.10 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.11 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.12 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.13 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.14 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_gate.15 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.0 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.1 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.2 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.3 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.4 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.5 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.6 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.7 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.8 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.9 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.10 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.11 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.12 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.13 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.14 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_up.15 torch.Size([688, 4096]) +model.layers.31.mlp.calculator.experts.weight_down.0 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.1 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.2 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.3 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.4 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.5 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.6 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.7 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.8 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.9 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.10 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.11 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.12 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.13 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.14 torch.Size([4096, 688]) +model.layers.31.mlp.calculator.experts.weight_down.15 torch.Size([4096, 688]) +model.layers.31.input_layernorm.weight torch.Size([4096]) +model.layers.31.post_attention_layernorm.weight torch.Size([4096]) +model.norm.weight torch.Size([4096]) +lm_head.weight torch.Size([32000, 4096]) +""" diff --git a/smoe/utils/debugging.py b/smoe/utils/debugging.py index 366be43..a80e695 100644 --- a/smoe/utils/debugging.py +++ b/smoe/utils/debugging.py @@ -1,8 +1,10 @@ +import socket + import debugpy import torch.distributed as dist -def remote_breakpoint(host: str = "0.0.0.0", port: int = 5678): +def remote_breakpoint(host: str = "0.0.0.0", port: int = 5678, rank: int = 0): """ This function helps to debug programs running in the remote computing node. @@ -41,10 +43,18 @@ def remote_breakpoint(host: str = "0.0.0.0", port: int = 5678): After the program starts and encounters the breakpoint, you could remote attach the debugger. """ + + def _dp(): + print( + f"Waiting for debugger to attach on {host}:{port}, server: {socket.gethostname()}..." + ) + debugpy.listen((host, port)) + debugpy.wait_for_client() + breakpoint() + if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank() - if rank == 0: - debugpy.listen((host, port)) - debugpy.wait_for_client() - breakpoint() + if dist.get_rank() == rank: + _dp() dist.barrier() + else: + _dp() diff --git a/smoe/utils/logging.py b/smoe/utils/logging.py index 90adecd..cdffdd8 100644 --- a/smoe/utils/logging.py +++ b/smoe/utils/logging.py @@ -7,7 +7,7 @@ # Setup logging logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + format="%(asctime)s - %(levelname)s - %(name)s - %(filename)s - %(funcName)s - %(processName)s(%(process)d)/%(threadName)s(%(thread)d) %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], diff --git a/smoe/utils/param_estimation.py b/smoe/utils/param_estimation.py index efe30e0..b3cfc60 100644 --- a/smoe/utils/param_estimation.py +++ b/smoe/utils/param_estimation.py @@ -21,11 +21,13 @@ def estimate_moe_param( post_attn_norm = hidden_size dense_one_layer = self_attn + mlp + input_norm + post_attn_norm - dense_params = emb + lm_head + final_norm + dense_one_layer * num_hidden_layers + dense_mid = dense_one_layer * num_hidden_layers + dense_params = emb + lm_head + final_norm + dense_mid gate = hidden_size * num_experts + num_experts * num_selects moe_one_layer = self_attn + mlp + input_norm + post_attn_norm + gate - moe_total_params = emb + lm_head + final_norm + moe_one_layer * num_hidden_layers + moe_mid = moe_one_layer * num_hidden_layers + moe_total_params = emb + lm_head + final_norm + moe_mid moe_one_act_layer = ( self_attn @@ -34,12 +36,16 @@ def estimate_moe_param( + post_attn_norm + gate ) - moe_act_params = emb + lm_head + final_norm + moe_one_act_layer * num_hidden_layers + moe_act_mid = moe_one_act_layer * num_hidden_layers + moe_act_params = emb + lm_head + final_norm + moe_act_mid return { "dense_params": dense_params, "moe_total_params": moe_total_params, "moe_act_params": moe_act_params, + "dense_mid": dense_mid, + "moe_mid": moe_mid, + "moe_act_mid": moe_act_mid, } @@ -62,3 +68,25 @@ def estimate_moe_param( 32000, 3200, 26, 8640 * num_experts, num_experts, 1 ) print(f"3B upcycling {num_experts} experts", res_3B_up) + + # ShearedLlama-1.3B upcycling + for num_experts in range(1, 17): + res_3B_up = estimate_moe_param( + 32000, 2048, 24, 5504 * num_experts, num_experts, 1 + ) + print(f"ShearedLlama-1.3B upcycling {num_experts} experts", res_3B_up) + + # ShearedLlama-2.7B upcycling + for num_experts in range(1, 17): + res_3B_up = estimate_moe_param( + 32000, 2560, 32, 6912 * num_experts, num_experts, 1 + ) + print(f"ShearedLlama-2.7B upcycling {num_experts} experts", res_3B_up) + + # 7B, moe half layers + res_7B_half = estimate_moe_param(32000, 4096, 8, 11008, 16, 2) + print("7B half 8 layers", res_7B_half) + res_7B_half = estimate_moe_param(32000, 4096, 24, 11008, 16, 2) + print("7B half 24 layers", res_7B_half) + res_7B_half = estimate_moe_param(32000, 4096, 16, 11008, 16, 1) + print("7B half 16 layers 1/16", res_7B_half) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 26722f6..a65b30d 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -102,6 +102,12 @@ def test_weighted_streaming(): def test_weighted_streaming_loader(): + # from datasets import IterableDataset + # ds = IterableDataset.from_generator() + from accelerate import Accelerator + + ac = Accelerator() + prob_map = { "en_cc": 0.67, "en_c4": 0.15, @@ -111,7 +117,7 @@ def test_weighted_streaming_loader(): "en_arxiv": 0.025, "en_stack": 0.02, } - num_test_case = 2000 + num_test_case = 2 block_size = 2048 bsz = 1 @@ -126,14 +132,28 @@ def test_weighted_streaming_loader(): batch_size=bsz, num_workers=0, collate_fn=fault_tolerance_data_collator, - pin_memory=True, + pin_memory=False, ) + loader = ac.prepare_data_loader(loader) + for batch_idx, batch in enumerate(loader): + if batch_idx == 0: + print(f"RANK {ac.process_index}/{ac.num_processes} - {batch}") if num_test_case <= 0: break assert len(batch["input_ids"]) == bsz - assert sum(loader.dataset.consumed_tokens.values()) == bsz * block_size + print( + f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + ) + # assert sum(loader.dataset.consumed_tokens.values()) == (batch_idx + 1) * block_size + print(loader.dataset.prob_map) num_test_case -= 1 + lm_datasets.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) + # loader.dataset.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) + print(loader.dataset.prob_map) + print( + f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + ) def test_skip_tokens(): From 073487fc1995eead5ddca8e454cd40fe4387a62a Mon Sep 17 00:00:00 2001 From: zhutong Date: Thu, 16 Nov 2023 10:27:09 +0800 Subject: [PATCH 3/4] support dynamic batch loading --- .gitignore | 13 +- .../cpt/dynamic_data_selection/baseline.sh | 1 - .../dynamic_data_selection/baseline_32gpus.sh | 13 +- .../sheared_llama_112gpus.sh | 17 +- .../sheared_llama_32gpus.sh | 164 ++++++++++++++++++ smoe/callbacks/tensorboard.py | 12 +- smoe/data/streaming.py | 127 ++++---------- smoe/entrypoint/cpt/cpt_fpt.py | 2 +- smoe/trainer/llama_lr_scheduling.py | 18 +- smoe/utils/logging.py | 2 +- tests/data/test_streaming.py | 12 +- 11 files changed, 242 insertions(+), 139 deletions(-) create mode 100644 scripts/cpt/dynamic_data_selection/sheared_llama_32gpus.sh diff --git a/.gitignore b/.gitignore index db46c66..387e2de 100644 --- a/.gitignore +++ b/.gitignore @@ -166,18 +166,7 @@ resources/ outputs/ /backups/ /visualization/ -results/analysis/cluster_*.png -results/expert_load_vis/ -results/analysis_clustering*/ -results/gate_loss_100b/ -results/RandomSplit-l2_norm-llama_7B-16Select4-up_proj/ -results/gate_loss_original_clustering_model/ -results/llama_7B_MoE_16Select4-l2_norm/ -results/random_16select4_moe/ -results/llama2_7B_gradient_share_gate_load/ -results/gate_loss.png -results/scale_distribution/ -results/analysis_scale_factor/ +results/ smoe/utils/gpu_diag.py /visualization_change_13B/ /visualization_change_7B/ diff --git a/scripts/cpt/dynamic_data_selection/baseline.sh b/scripts/cpt/dynamic_data_selection/baseline.sh index c075bbd..b260e10 100644 --- a/scripts/cpt/dynamic_data_selection/baseline.sh +++ b/scripts/cpt/dynamic_data_selection/baseline.sh @@ -122,7 +122,6 @@ source ~/anaconda3/bin/activate smoe --per_device_train_batch_size ${per_device_train_batch_size} \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ --do_train \ - --do_eval \ --evaluation_strategy steps \ --eval_steps ${eval_steps} \ --seed ${seed} \ diff --git a/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh b/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh index 9d144ac..7cd2458 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_32gpus.sh @@ -1,8 +1,8 @@ #!/usr/bin/bash -#SBATCH --job-name=llama2_random_scale4_32gpus -#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus/%x-%j.log -#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus/%x-%j.log +#SBATCH --job-name=llama2_random_scale4_32gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus_dynamic_data/%x-%j.log #SBATCH --partition=MoE #SBATCH --ntasks-per-node=1 @@ -55,11 +55,11 @@ source ~/anaconda3/bin/activate smoe dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized - lr=1e-4 + lr=2e-4 final_lr_portion=0.1 per_device_train_batch_size=8 per_device_eval_batch_size=8 - gradient_accumulation_steps=8 + gradient_accumulation_steps=4 block_size=4096 num_tokens="200*10^9" warmup_tokens="1*10^9" @@ -84,7 +84,7 @@ source ~/anaconda3/bin/activate smoe echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" data_cache=resources/cache - base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus" + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus_dynamic_data" output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID mkdir -p $output_dir echo "output_dir: $output_dir" @@ -122,7 +122,6 @@ source ~/anaconda3/bin/activate smoe --per_device_train_batch_size ${per_device_train_batch_size} \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ --do_train \ - --do_eval \ --evaluation_strategy steps \ --eval_steps ${eval_steps} \ --seed ${seed} \ diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh index b364ee6..14d3375 100644 --- a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh @@ -1,8 +1,8 @@ #!/usr/bin/bash -#SBATCH --job-name=cpt-llama2_random_scale4_112gpus -#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus/%x-%j.log -#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus/%x-%j.log +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log #SBATCH --partition=MoE #SBATCH --ntasks-per-node=1 @@ -40,7 +40,7 @@ source ~/anaconda3/bin/activate smoe # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" model_type="llama_moe" - comment="llama 2 7B, random 4/16, per-device bsz 4M tokens, lr 1e-4, 16gpua" + comment="llama 2 7B, random 4/16" pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" @@ -55,11 +55,11 @@ source ~/anaconda3/bin/activate smoe dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized - lr=1e-4 + lr=2e-4 final_lr_portion=0.1 per_device_train_batch_size=8 per_device_eval_batch_size=8 - gradient_accumulation_steps=8 + gradient_accumulation_steps=4 block_size=4096 num_tokens="200*10^9" warmup_tokens="15*10^8" @@ -84,7 +84,7 @@ source ~/anaconda3/bin/activate smoe echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" data_cache=resources/cache - base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus" + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID mkdir -p $output_dir echo "output_dir: $output_dir" @@ -123,7 +123,6 @@ source ~/anaconda3/bin/activate smoe --per_device_train_batch_size ${per_device_train_batch_size} \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ --do_train \ - --do_eval \ --evaluation_strategy steps \ --eval_steps ${eval_steps} \ --seed ${seed} \ @@ -136,7 +135,7 @@ source ~/anaconda3/bin/activate smoe --learning_rate ${lr} \ --weight_decay 0.1 \ --max_grad_norm 1.0 \ - --warmup_steps ${warmup_steps} \ + --warmup_steps 100 \ --max_steps ${max_steps} \ --max_train_samples ${max_train_samples} \ --save_strategy steps \ diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_32gpus.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_32gpus.sh new file mode 100644 index 0000000..cb76356 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_32gpus.sh @@ -0,0 +1,164 @@ +#!/usr/bin/bash + +#SBATCH --job-name=llama2_random_scale4_32gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=4 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=4 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, per-device bsz 4M tokens, lr 1e-4, 16gpus" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="1*10^9" + # warmup_tokens="0" + eval_tokens="1*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / $block_size" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / $tokens_per_batch" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / $tokens_per_batch" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_32gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --dynamic_data_selection "sheared_llama" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/smoe/callbacks/tensorboard.py b/smoe/callbacks/tensorboard.py index 1bb0585..8a15fe7 100644 --- a/smoe/callbacks/tensorboard.py +++ b/smoe/callbacks/tensorboard.py @@ -93,10 +93,16 @@ def on_log( self.tb_writer.add_image( k, get_heatmap_img_grid_for_tb(v), state.global_step ) - elif k == "train/consumed_tokens": - v.update({"total_tokens": sum(v.values())}) + elif k == "train/prob_map" and isinstance(v, dict): for name, val in v.items(): self.tb_writer.add_scalar( - f"consumed_tokens/{name}", val, state.global_step + f"prob_map/{name}", val, state.global_step ) + + # elif k == "train/consumed_tokens": + # v.update({"total_tokens": sum(v.values())}) + # for name, val in v.items(): + # self.tb_writer.add_scalar( + # f"consumed_tokens/{name}", val, state.global_step + # ) self.tb_writer.flush() diff --git a/smoe/data/streaming.py b/smoe/data/streaming.py index 5020ae0..2c91995 100644 --- a/smoe/data/streaming.py +++ b/smoe/data/streaming.py @@ -219,29 +219,6 @@ def __len__(self): return len(self.cached) -def batchify_loader(dataset: Iterable, batch_size: int, collate_fn: Callable): - batch = [] - for ins in dataset: - batch.append(ins) - if len(batch) >= batch_size: - yield collate_fn(batch) - batch.clear() - if len(batch) > 0: - yield collate_fn(batch) - batch.clear() - - -class BufferAggregation: - def __init__(self, block_size: int) -> None: - self.block_size = block_size - - def __call__(self, buffer) -> Any: - results = buffer - if self.block_size > 0 and len(buffer) > 0: - results = group_instances(buffer, self.block_size) - return results - - class PackedJsonlDataset(IterableDataset): def __init__( self, @@ -249,71 +226,48 @@ def __init__( seed: int = 1227, buffer_size: int = 200, block_size: int = 2048, - skip_tokens: int = 0, ) -> None: super().__init__() self.data_dir = data_dir self.rng = random.Random(seed) self.buffer_size = buffer_size self.block_size = block_size - self.skip_tokens = skip_tokens data_dir_path = Path(data_dir) filepaths = sorted(data_dir_path.glob("**/*.jsonl")) self.rng.shuffle(filepaths) self.filepaths = filepaths - self.curr_filepath_pointer = -1 - self.consumed_tokens: int = 0 - - self.buffer_aggregation = BufferAggregation(self.block_size) - - def next_filepath(self) -> str: - if len(self.filepaths) == 0: - raise RuntimeError(f"There's no filepath in {self.data_dir}") - - self.curr_filepath_pointer += 1 - if self.curr_filepath_pointer >= len(self.filepaths): - self.curr_filepath_pointer = 0 - - num_skipped_filepath = 0 - while self.consumed_tokens < self.skip_tokens: - filepath = self.filepaths[self.curr_filepath_pointer] - # meta: [[current token number in the whole file, length of the current instance], ...] - meta: np.ndarray = np.load(filepath + META_SUFFIX) - curr_filepath_tokens = meta.sum(axis=0)[1] - if self.consumed_tokens + curr_filepath_tokens > self.skip_tokens: - break - self.consumed_tokens += curr_filepath_tokens - self.curr_filepath_pointer += 1 - if self.curr_filepath_pointer >= len(self.filepaths): - self.curr_filepath_pointer = 0 - num_skipped_filepath += 1 - - if num_skipped_filepath > 0: - logger.info( - f"Skip {num_skipped_filepath} files," - f" {self.consumed_tokens} tokens," - f" remaining {self.skip_tokens - self.consumed_tokens} tokens to skip." - ) - - return self.filepaths[self.curr_filepath_pointer] + self.buffer = [] def __iter__(self) -> Iterator: - filepath = self.next_filepath() - logger.debug(f"Iter over jsonl file: {filepath}") - ds = load_jsonlines_iter(filepath) - # if self.consumed_tokens < self.skip_tokens: - # remaining_skip_tokens = self.skip_tokens - self.consumed_tokens - # # zhutong: here, the skip method is not perfect since there is batch grouping, - # # and the final token number per instance may be different. - # num_skip_lines = (meta[:, 1].cumsum() > remaining_skip_tokens).nonzero()[0][0] - # ds.skip_lines(num_skip_lines) - # self.consumed_tokens += meta[:num_skip_lines].sum(axis=0)[1] - for batch in batchify_loader(ds, self.buffer_size, self.buffer_aggregation): - for ins in batch: - if self.consumed_tokens >= self.skip_tokens: - self.consumed_tokens += len(ins["input_ids"]) + self.buffer = [] + for filepath in self.filepaths: + logger.debug(f"Iter over jsonl file: {filepath}") + for ins in load_jsonlines_iter(filepath): + if self.buffer_size <= 1: yield ins + continue + + if len(self.buffer) >= self.buffer_size: + if len(self.buffer) > 0: + self.rng.shuffle(self.buffer) + self.buffer_aggregation() + yield from self.buffer + self.buffer.clear() + + self.buffer.append(ins) + + # for the last batch < buffer_size + if len(self.buffer) > 0: + self.rng.shuffle(self.buffer) + self.buffer_aggregation() + yield from self.buffer + self.buffer.clear() + + def buffer_aggregation(self): + if self.block_size > 0 and len(self.buffer) > 0: + results = group_instances(self.buffer, self.block_size) + self.buffer = results def state_dict(self): return { @@ -323,7 +277,6 @@ def state_dict(self): "buffer_size": self.buffer_size, "block_size": self.block_size, "filepaths": self.filepaths, - "consumed_tokens": self.consumed_tokens, } @@ -360,7 +313,6 @@ def __init__( seed: int = 1227, buffer_size: int = 200, block_size: int = 2048, - skip_tokens: dict = {}, ) -> None: self.rng = random.Random(seed) self.seed = seed @@ -391,7 +343,6 @@ def __init__( self.source2idx[task_type] = len(self.source2idx) self.prob_map[task_type] = sampling_weight - self.consumed_tokens = skip_tokens self.task_type_to_dataset = {} for task_type in task_types: # zhutong: use iter to support next() calling, since the dataset itself @@ -402,26 +353,10 @@ def __init__( seed=seed, buffer_size=buffer_size, block_size=block_size, - skip_tokens=skip_tokens.get(task_type, 0), ) ) self.task_type_to_dataset[task_type] = ds - def skip_tokens(self, skip_tokens: dict): - for task_type, num_skip_tokens in skip_tokens.items(): - self.task_type_to_dataset[task_type] = iter( - PackedJsonlDataset( - str(self.dataset_dir_path.joinpath(task_type)), - seed=self.seed, - buffer_size=self.buffer_size, - block_size=self.block_size, - skip_tokens=skip_tokens.get(task_type, 0), - ) - ) - if task_type not in self.consumed_tokens: - self.consumed_tokens[task_type] = 0 - self.consumed_tokens[task_type] += num_skip_tokens - def update_prob_map(self, new_prob_map: dict): self.prob_map.update(new_prob_map) @@ -436,11 +371,7 @@ def __iter__(self) -> Iterator: weights = [self.prob_map[task_type] for task_type in candidate_task_types] choice = self.rng.choices(candidate_task_types, weights=weights, k=1)[0] try: - ins = next(self.task_type_to_dataset[choice]) - if choice not in self.consumed_tokens: - self.consumed_tokens[choice] = 0 - self.consumed_tokens[choice] += len(ins["input_ids"]) - yield ins + yield next(self.task_type_to_dataset[choice]) except StopIteration: # self.task_type_to_dataset.pop(choice) # logger.debug(f"Task type {choice} finished, drop it") diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py index 1ad0e6d..23a79b5 100644 --- a/smoe/entrypoint/cpt/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -60,7 +60,7 @@ logger = logging.getLogger(__name__) -# @wechat_sender() +@wechat_sender() def main(): model_args, data_args, training_args = parse_args( ModelArguments, DataArguments, EnhancedTrainingArguments diff --git a/smoe/trainer/llama_lr_scheduling.py b/smoe/trainer/llama_lr_scheduling.py index 95c3a76..fea9c51 100644 --- a/smoe/trainer/llama_lr_scheduling.py +++ b/smoe/trainer/llama_lr_scheduling.py @@ -24,6 +24,8 @@ import torch.nn as nn from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader + +# from torch.profiler import profile, schedule, tensorboard_trace_handler, ProfilerActivity from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.deepspeed import deepspeed_init from transformers.dependency_versions_check import dep_version_check @@ -342,7 +344,8 @@ def _maybe_log_save_evaluate( x.detach().cpu().tolist() for x in gate_importance ] logs["balance_loss"] = balance_loss.item() - logs["consumed_tokens"] = self.state.tot_consumed_tokens + logs["tot_consumed_tokens"] = self.state.tot_consumed_tokens + logs["prob_map"] = self.train_dataset.prob_map self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step @@ -825,6 +828,18 @@ def _inner_training_loop( steps_trained_in_current_epoch = 0 rng_to_sync = True + # tracing_schedule = schedule(skip_first=5, wait=5, warmup=2, active=2, repeat=1) + # trace_handler = tensorboard_trace_handler(dir_name="/mnt/petrelfs/zhutong/smoe/results/profiling", use_gzip=True) + + # with profile( + # activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # schedule=tracing_schedule, + # on_trace_ready=trace_handler, + # profile_memory=True, + # record_shapes=True, + # with_stack=True + # ) as prof: + step = -1 for step, inputs in enumerate(epoch_iterator): total_batched_samples += 1 @@ -955,6 +970,7 @@ def _inner_training_loop( model.zero_grad() self.state.global_step += 1 + # prof.step() # self.state.consumed_tokens = self.train_dataset.consumed_tokens self.state.tot_consumed_tokens += self.args.num_tokens_per_batch self.state.epoch = ( diff --git a/smoe/utils/logging.py b/smoe/utils/logging.py index cdffdd8..90adecd 100644 --- a/smoe/utils/logging.py +++ b/smoe/utils/logging.py @@ -7,7 +7,7 @@ # Setup logging logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(filename)s - %(funcName)s - %(processName)s(%(process)d)/%(threadName)s(%(thread)d) %(message)s", + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index a65b30d..8d458a3 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -142,18 +142,18 @@ def test_weighted_streaming_loader(): if num_test_case <= 0: break assert len(batch["input_ids"]) == bsz - print( - f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" - ) + # print( + # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + # ) # assert sum(loader.dataset.consumed_tokens.values()) == (batch_idx + 1) * block_size print(loader.dataset.prob_map) num_test_case -= 1 lm_datasets.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) # loader.dataset.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) print(loader.dataset.prob_map) - print( - f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" - ) + # print( + # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + # ) def test_skip_tokens(): From 0a9e15ef35c6e1608cde6176a75979ec41c3011e Mon Sep 17 00:00:00 2001 From: zhutong Date: Thu, 16 Nov 2023 14:45:37 +0800 Subject: [PATCH 4/4] add connection testing --- .../sheared_llama_112gpus.sh | 9 +++--- scripts/test/test_conn.sh | 30 +++++++++++++++++++ tests/entrypoint/__init__.py | 0 tests/entrypoint/test_conn.py | 30 +++++++++++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 scripts/test/test_conn.sh create mode 100644 tests/entrypoint/__init__.py create mode 100644 tests/entrypoint/test_conn.py diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh index 14d3375..a24d857 100644 --- a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh @@ -12,6 +12,7 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot @@ -64,23 +65,23 @@ source ~/anaconda3/bin/activate smoe num_tokens="200*10^9" warmup_tokens="15*10^8" # warmup_tokens="0" - eval_tokens="1*10^9" + eval_tokens="2.5*10^9" seed=1227 deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json num_selects=4 max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) - max_train_samples=$(echo "${num_tokens} / $block_size" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) echo "max_steps: $max_steps" echo "max_train_samples: $max_train_samples" global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) echo "global batch size: $global_bs" tokens_per_batch=$(echo "$global_bs * $block_size" | bc) echo "#tokens/batch: $tokens_per_batch" - warmup_steps=$(echo "$warmup_tokens / $tokens_per_batch" | bc) + warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" - eval_steps=$(echo "$eval_tokens / $tokens_per_batch" | bc) + eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" data_cache=resources/cache diff --git a/scripts/test/test_conn.sh b/scripts/test/test_conn.sh new file mode 100644 index 0000000..dcf13bc --- /dev/null +++ b/scripts/test/test_conn.sh @@ -0,0 +1,30 @@ +#!/usr/bin/bash + +#SBATCH --job-name=test_conn +#SBATCH --output=logs/%x.log +#SBATCH --error=logs/%x.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=3 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -w SH-IDCA1404-10-140-54-11,SH-IDCA1404-10-140-54-36 + +export OMP_NUM_THREADS=4 + +nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) +nodes_array=($nodes) +head_node=${nodes_array[0]} + +srun torchrun \ + --nnodes 3 \ + --nproc_per_node 8 \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29520 \ + tests/entrypoint/test_conn.py diff --git a/tests/entrypoint/__init__.py b/tests/entrypoint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/entrypoint/test_conn.py b/tests/entrypoint/test_conn.py new file mode 100644 index 0000000..c3888a8 --- /dev/null +++ b/tests/entrypoint/test_conn.py @@ -0,0 +1,30 @@ +import os +import socket + +import torch +import torch.distributed as dist +import torch.nn as nn + +# from accelerate import Accelerator + + +def test_connection(): + string = f"{socket.gethostname()} - MASTER_ADDR: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']} - WORLD_SIZE: {os.environ['WORLD_SIZE']} - RANK: {os.environ['RANK']}" + print(string) + dist.init_process_group("nccl") + # ac = Accelerator() + m = nn.Linear(5, 10) + m = nn.parallel.DistributedDataParallel(m, device_ids=[dist.get_rank()]) + # m = ac.prepare_model(m) + x = torch.randn(3, 5, device=m.device) + y = m(x) + # dist.all_reduce(y, op=dist.ReduceOp.SUM) + assert y.shape == (3, 10) + # print(f"Done - local: {ac.local_process_index} - rank: {ac.process_index} - world: {ac.num_processes}") + print( + f"Done - {socket.gethostname()} - local: {os.environ['LOCAL_RANK']} - rank: {dist.get_rank()} - world: {dist.get_world_size()}" + ) + + +if __name__ == "__main__": + test_connection()