From 2c0efb621ca5b27e36246308e6020ded99af7f34 Mon Sep 17 00:00:00 2001 From: ShinoharaHare Date: Fri, 4 Oct 2024 19:41:03 +0800 Subject: [PATCH 1/5] =?UTF-8?q?refactor:=20=E7=B5=B1=E4=B8=80=E8=BC=B8?= =?UTF-8?q?=E5=87=BA=E8=B3=87=E6=96=99=E9=9B=86=E8=B3=87=E8=A8=8A=E4=BB=8B?= =?UTF-8?q?=E9=9D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/pre_process_data.py | 55 ++++++++++++++ scripts/pre_process_pre_training_data.py | 76 ------------------- src/llm_training/data/__init__.py | 1 + src/llm_training/data/base_datamodule.py | 23 +++++- .../pre_training/pre_training_datamodule.py | 47 +++++++++++- 5 files changed, 120 insertions(+), 82 deletions(-) create mode 100644 scripts/pre_process_data.py delete mode 100644 scripts/pre_process_pre_training_data.py diff --git a/scripts/pre_process_data.py b/scripts/pre_process_data.py new file mode 100644 index 0000000..daa6b0b --- /dev/null +++ b/scripts/pre_process_data.py @@ -0,0 +1,55 @@ +import os +import sys +from typing import TextIO + +from llm_training.data import * +from llm_training.models import * +from llm_training.overrides.cli import * + + +class OutputStreamRedirector: + def __init__(self, *streams: TextIO) -> None: + self._streams = streams + + def write(self, s: str) -> int: + n = 0 + for stream in self._streams: + n += stream.write(s) + return n + + def flush(self) -> None: + for s in self._streams: + s.flush() + + +def main(): + cli = LightningCLI(run=False) + + datamodule = cli.datamodule + + assert isinstance(datamodule, HFBasedDataModule) + + config = datamodule.config + + enable_cache = config.enable_cache + pre_processed_data_path = config.pre_processed_data_path + + if not os.path.exists(pre_processed_data_path) or len(os.listdir(pre_processed_data_path)) == 0: + config.pre_processed_data_path = None + config.enable_cache = True + datamodule.setup() + datamodule.save_pre_processed_data(pre_processed_data_path) + else: + print(f'`pre_processed_data_path="{pre_processed_data_path}"` is not empty, skipping.') + datamodule.setup() + + if not enable_cache: + n = datamodule.cleanup_cache_files() + print(f'Cleanup cache files: {n}') + + with open(os.path.join(pre_processed_data_path, 'info.txt'), 'w') as f: + datamodule.print_dataset_info(file=OutputStreamRedirector(sys.stdout, f)) + + +if __name__ == '__main__': + main() diff --git a/scripts/pre_process_pre_training_data.py b/scripts/pre_process_pre_training_data.py deleted file mode 100644 index 7a9ef39..0000000 --- a/scripts/pre_process_pre_training_data.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -from collections import Counter - -from datasets import DatasetDict -from tabulate import tabulate -from tqdm.auto import tqdm - -from llm_training.data import * -from llm_training.models import * -from llm_training.overrides.cli import * - - -def get_tokens_table(dataset_dict: DatasetDict) -> str: - tokens: dict[str, Counter[str, int]] = {} - for k, dataset in dataset_dict.items(): - counter = Counter() - dataset = dataset.select_columns(['source', 'length']) - progress = tqdm(total=len(dataset), desc=f'Count tokens ({k})') - for batch in dataset.iter(1000): - batch_size = len(batch['length']) - for source, length in zip(batch['source'], batch['length']): - counter[source] += length - counter['all'] += length - tokens[k] = counter - progress.set_postfix(tokens=counter['all']) - progress.update(batch_size) - progress.clear() - - return tabulate( - [ - [split, source, tokens] - for split, counter in tokens.items() - for source, 
tokens in counter.most_common() - ], - headers=['Split', 'Source', 'Tokens'], - tablefmt='orgtbl' - ) - - -def main(): - cli = LightningCLI(run=False) - - datamodule: PreTrainingDataModule = cli.datamodule - config = datamodule.config - - enable_cache = config.enable_cache - pre_processed_data_path = config.pre_processed_data_path - - if not os.path.exists(pre_processed_data_path) or len(os.listdir(pre_processed_data_path)) == 0: - config.pre_processed_data_path = None - config.enable_cache = True - datamodule.setup() - datamodule.save_pre_processed_data(pre_processed_data_path) - else: - print(f'`pre_processed_data_path="{pre_processed_data_path}"` is not empty, skipping.') - datamodule.setup() - - if not enable_cache: - datamodule.cleanup_cache_files() - - table_string = ( - 'Original Tokens:\n' - + get_tokens_table(datamodule.pre_processed_dataset_dict) - + '\n\n' - + 'Sampled Tokens:\n' - + get_tokens_table(datamodule.dataset_dict) - ) - - with open(os.path.join(pre_processed_data_path, 'tokens.txt'), 'w') as f: - f.write(table_string + '\n') - - print(table_string) - - -if __name__ == '__main__': - main() diff --git a/src/llm_training/data/__init__.py b/src/llm_training/data/__init__.py index fd78ef0..33354a7 100644 --- a/src/llm_training/data/__init__.py +++ b/src/llm_training/data/__init__.py @@ -2,6 +2,7 @@ from .base_datamodule import BaseDataModule from .base_datamodule_config import BaseDataModuleConfig from .dummy import * +from .hf_based import * from .instruction_tuning import * from .pre_training import * from .preference_tuning import * diff --git a/src/llm_training/data/base_datamodule.py b/src/llm_training/data/base_datamodule.py index 0802684..8d53737 100644 --- a/src/llm_training/data/base_datamodule.py +++ b/src/llm_training/data/base_datamodule.py @@ -1,6 +1,7 @@ import logging +import os from functools import partial -from typing import Mapping +from typing import Mapping, TextIO import lightning as L from torch.utils.data import DataLoader, Dataset @@ -26,9 +27,7 @@ def __init__(self, config: BaseDataModuleConfig) -> None: self.raw_dataset_dict = None self.pre_processed_dataset_dict = None self.dataset_dict = None - - self.train_dataloader_state = {} - + def load_data(self) -> DatasetDict: raise NotImplementedError() @@ -53,6 +52,22 @@ def save_pre_processed_data(self, path: str | None = None) -> None: def load_pre_processed_data(self, path: str | None = None) -> None: raise NotImplementedError() + def print_dataset_info(self, file: TextIO | None = None) -> None: + print_ = partial(print, file=file) + def print_header(header: str) -> None: + n = os.get_terminal_size().columns + m = (n - len(header) - 2) // 2 + divider = '=' * m + header = f'{divider} {header} {divider}' + print_(f'{header:^{n}}', end='\n\n') + + print_header('Raw Dataset') + print_(self.raw_dataset_dict, end='\n\n') + print_header('Pre-processed Dataset') + print_(self.pre_processed_dataset_dict, end='\n\n') + print_header('Final Dataset') + print_(self.dataset_dict, end='\n\n') + def _get_dataloader(self, split: str): dataloader_class = DataLoader dataloader_kwargs = dict( diff --git a/src/llm_training/data/pre_training/pre_training_datamodule.py b/src/llm_training/data/pre_training/pre_training_datamodule.py index 3d867d7..3060984 100644 --- a/src/llm_training/data/pre_training/pre_training_datamodule.py +++ b/src/llm_training/data/pre_training/pre_training_datamodule.py @@ -3,14 +3,17 @@ import os import pickle import random +from collections import Counter +from functools import partial from 
typing import Any, Iterable from datasets import Dataset +from tabulate import tabulate from tqdm.auto import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from llm_training.data.hf_based.hf_based_datamodule import (DatasetDict, - HFBasedDataModule) + HFBasedDataModule) from .pre_training_datacollator import PreTrainingDataCollator from .pre_training_datamodule_config import (ConcatMethod, @@ -158,6 +161,46 @@ def post_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: dataset_dict = self.sample_data(dataset_dict, source_indices) return dataset_dict + def print_dataset_info(self, file: str | None) -> None: + super().print_dataset_info(file) + + print_ = partial(print, file=file) + + def get_tokens_table(dataset_dict: DatasetDict) -> str: + tokens: dict[str, Counter[str, int]] = {} + for k, dataset in dataset_dict.items(): + counter = Counter() + dataset = dataset.select_columns(['source', 'length']) + with tqdm( + total=len(dataset), + desc=f'Count tokens ({k})', + leave=False + ) as progress: + for batch in dataset.iter(1000): + batch_size = len(batch['length']) + for source, length in zip(batch['source'], batch['length']): + counter[source] += length + counter['*'] += length + tokens[k] = counter + progress.set_postfix(tokens=counter['*']) + progress.update(batch_size) + + return tabulate( + [ + [split, source, tokens] + for split, counter in tokens.items() + for source, tokens in counter.most_common() + ], + headers=['Split', 'Source', 'Tokens'], + tablefmt='orgtbl' + ) + + print_('=' * os.get_terminal_size().columns, end='\n\n') + print_('Original Tokens:') + print_(get_tokens_table(self.pre_processed_dataset_dict), end='\n\n') + print_('Sampled Tokens:') + print_(get_tokens_table(self.dataset_dict)) + def _tokenize( batch: dict[str, list[str | Any]], @@ -180,7 +223,7 @@ def _tokenize( batch_text = [batch['text'][i] for i in selected_indices] batch = {k: [batch[k][i] for i in selected_indices] for k in keep_columns} - batch['input_ids'] = tokenizer( + batch['input_ids'] = tokenizer.batch_encode_plus( batch_text, add_special_tokens=False, return_token_type_ids=False, From e60b5bf24d9f260816ebd8ecd95b6a4b4a69dc9b Mon Sep 17 00:00:00 2001 From: ShinoharaHare Date: Fri, 4 Oct 2024 19:41:53 +0800 Subject: [PATCH 2/5] =?UTF-8?q?feat:=20=E6=94=AF=E6=8F=B4=20split=20slicin?= =?UTF-8?q?g=E3=80=81=E6=94=B9=E5=9C=A8=20load=5Fdata=20=E6=B8=85=E9=99=A4?= =?UTF-8?q?=E5=BF=AB=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_training/data/hf_based/hf_based_datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/llm_training/data/hf_based/hf_based_datamodule.py b/src/llm_training/data/hf_based/hf_based_datamodule.py index 90b774e..1d92bab 100644 --- a/src/llm_training/data/hf_based/hf_based_datamodule.py +++ b/src/llm_training/data/hf_based/hf_based_datamodule.py @@ -41,11 +41,15 @@ def load_data(self) -> DatasetDict: dataset_dict = load_dataset(**dataset_kwargs) - if (split := dataset_kwargs.get('split', False)): - dataset_dict = DatasetDict({split: dataset_dict}) + if isinstance(dataset_dict, Dataset): + dataset_dict = DatasetDict({'train': dataset_dict}) assert self.config.validation_split is None or 'train' in dataset_dict and 'validation' not in dataset_dict + if self.config.cleanup_cache_files: + n = dataset_dict.cleanup_cache_files() + logger.info(f'Cleanup cache files: {n}') + return dataset_dict def split_data(self, dataset_dict: DatasetDict): @@ -58,8 
+62,6 @@ def prepare_data(self) -> None: if self.config.pre_processed_data_path is None: with cache_context(self.config.enable_cache): dataset_dict = self.load_data() - if self.config.cleanup_cache_files: - dataset_dict.cleanup_cache_files() self.pre_process_data(dataset_dict) def setup(self, stage: str | None = None) -> None: From a4c1618103624f7d42e6620acac75879c6af753e Mon Sep 17 00:00:00 2001 From: ShinoharaHare Date: Fri, 4 Oct 2024 19:42:40 +0800 Subject: [PATCH 3/5] =?UTF-8?q?doc:=20=E6=9B=B4=E6=96=B0=E9=A0=90=E8=99=95?= =?UTF-8?q?=E7=90=86=E8=85=B3=E6=9C=AC=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/pre_training.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pre_training.md b/docs/pre_training.md index cb8c1c6..7fc9472 100644 --- a/docs/pre_training.md +++ b/docs/pre_training.md @@ -21,7 +21,7 @@ For a complete set of parameters, please refer to [`PreTrainingDataModuleConfig` Before training begins, the framework automatically processes the data, ensuring everything is ready before training starts. This is particularly convenient when dealing with small training dataset. However, pre-training datasets are typically large, and CPUs used during training are often limited, making this step very time-consuming. -To address this issue, you can set `pre_processed_data_path` and use many CPUs to execute `scripts/pre_process_pre_training_data.py` for pre-processing and saving the data in advance. +To address this issue, you can set `pre_processed_data_path` and use many CPUs to execute `scripts/pre_process_data.py` for pre-processing and saving the data in advance. Remember to set `num_proc` to the desired number of CPUs to utilize. @@ -35,7 +35,7 @@ data: ``` ```bash -python scripts/pre_process_pre_training_data.py -c +python scripts/pre_process_data.py -c ``` ## Data Sampling From f19e55ba926450eaf68e9cb4ccff1dec5828ae0e Mon Sep 17 00:00:00 2001 From: ShinoharaHare Date: Fri, 4 Oct 2024 19:44:26 +0800 Subject: [PATCH 4/5] =?UTF-8?q?doc:=20=E6=96=B0=E5=A2=9E=E9=A0=90=E8=A8=AD?= =?UTF-8?q?=20`chat=5Ftemplate`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/examples/phi-3/phi-3-mini_dpo_example.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/examples/phi-3/phi-3-mini_dpo_example.yaml b/config/examples/phi-3/phi-3-mini_dpo_example.yaml index 6a905eb..add7881 100644 --- a/config/examples/phi-3/phi-3-mini_dpo_example.yaml +++ b/config/examples/phi-3/phi-3-mini_dpo_example.yaml @@ -60,6 +60,7 @@ data: class_path: HFTokenizer init_args: path: microsoft/Phi-3-mini-128k-instruct + chat_template: phi-3 batch_size: 1 max_length: 4096 pad_to_multiple_of: 64 From 913578e81e76915c7631d9071c4b857c10da9b71 Mon Sep 17 00:00:00 2001 From: ShinoharaHare Date: Fri, 4 Oct 2024 19:46:51 +0800 Subject: [PATCH 5/5] =?UTF-8?q?feat:=20=E6=94=B9=E7=94=A8=20`return=5Fassi?= =?UTF-8?q?stant=5Ftokens=5Fmask`=20=E4=BE=86=20mask=20=E6=8E=89=20user=20?= =?UTF-8?q?prompt=20=E7=9A=84=20label?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../data/chat_templates/__init__.py | 3 +- .../data/chat_templates/chatml.j2 | 25 ++-- src/llm_training/data/chat_templates/gemma.j2 | 36 ++--- .../data/chat_templates/llama-2.j2 | 51 ++++--- .../data/chat_templates/llama-3.j2 | 34 +++-- src/llm_training/data/chat_templates/phi-3.j2 | 34 +++-- .../data/chat_templates/tulu-2.j2 | 41 ++---- 
.../instruction_tuning_datamodule.py | 88 ++++++------ .../preference_tuning_datamodule.py | 128 ++++++++---------- 9 files changed, 205 insertions(+), 235 deletions(-) diff --git a/src/llm_training/data/chat_templates/__init__.py b/src/llm_training/data/chat_templates/__init__.py index a8a7daa..11ac6ad 100644 --- a/src/llm_training/data/chat_templates/__init__.py +++ b/src/llm_training/data/chat_templates/__init__.py @@ -5,7 +5,8 @@ class _ChatTemplates: def _get_path_by_name(self, name: str) -> Path: - return Path(__file__).parent.joinpath(name).with_suffix('.j2') + p = Path(__file__).parent / name + return p.with_suffix(f'{p.suffix}.j2') def __getitem__(self, name: str) -> str: if name not in self: diff --git a/src/llm_training/data/chat_templates/chatml.j2 b/src/llm_training/data/chat_templates/chatml.j2 index 15bbd29..9eeb7ae 100644 --- a/src/llm_training/data/chat_templates/chatml.j2 +++ b/src/llm_training/data/chat_templates/chatml.j2 @@ -1,11 +1,14 @@ -{% set is_splitted = index is defined and length is defined %} -{% for message in messages %} - {% set content = message['content'] + '<|im_end|>\n' %} - {% if message['role'] != 'assistant' or not is_splitted %} - {% set content = '<|im_start|>' + message['role'] + '\n' + content %} - {% endif %} - {{- content -}} -{% endfor %} -{% if add_generation_prompt %} - {{- '<|im_start|>assistant\n' -}} -{% endif %} +{%- for message in messages %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- set content = message.content + '<|im_end|>\n' %} + {%- if message.role == 'assistant' %} + {% generation %} + {{- content -}} + {% endgeneration %} + {%- else %} + {{- content }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/src/llm_training/data/chat_templates/gemma.j2 b/src/llm_training/data/chat_templates/gemma.j2 index 4af92b3..88c1b43 100644 --- a/src/llm_training/data/chat_templates/gemma.j2 +++ b/src/llm_training/data/chat_templates/gemma.j2 @@ -1,21 +1,15 @@ -{% set is_splitted = index is defined and length is defined %} -{% for message in messages %} - {% set index = index if is_splitted else loop.index0 %} - {% set length = length if is_splitted else messages|length %} - {% set role = message['role'] if message['role'] != 'assistant' else 'model' %} - {% set header = '' + role + '\n' %} - {% set content = message['content'] | trim + '\n' %} - {% if message['role'] != 'assistant' or not is_splitted %} - {% set content = header + content %} - {% endif %} - {% if index == 0 and loop.index0 == 0 %} - {% set content = bos_token + content %} - {% endif %} - {% if index == length - 1 and loop.index0 == messages|length - 1 and message['role'] == 'assistant' %} - {% set content = content + eos_token %} - {% endif %} - {{- content -}} -{% endfor %} -{% if add_generation_prompt %} - {{- 'model\n' -}} -{% endif %} +{{- bos_token }} +{%- for message in messages %} + {%- set content = message.content | trim + '\n' %} + {%- set role = 'model' if message.role == 'assistant' else message.role %} + {{- '' + role + '\n' }} + {%- if message.role == 'assistant' %} + {% generation %} + {{- content }} + {% endgeneration %} + {%- else %} + {{- content }} + {%- endif %} +{%- endfor %} + {%- if add_generation_prompt %}{{'model\n'}} +{%- endif %} diff --git a/src/llm_training/data/chat_templates/llama-2.j2 b/src/llm_training/data/chat_templates/llama-2.j2 index e80f075..70f78d9 100644 --- a/src/llm_training/data/chat_templates/llama-2.j2 +++ 
b/src/llm_training/data/chat_templates/llama-2.j2 @@ -1,27 +1,24 @@ -{% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} -{% else %} - {% set loop_messages = messages %} - {% set system_message = false %} -{% endif %} -{% if system_message != false and (loop_messages|length == 0 or loop_messages|length == 1 and loop_messages[0]['role'] != 'user') %} - {{ raise_exception('The system prompt must be passed along with a user prompt.') }} -{% endif %} -{% for message in loop_messages %} - {% if loop_messages|length > 1 and (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{- bos_token + '[INST] ' + content.strip() + ' [/INST]' -}} - {% elif message['role'] == 'system' %} - {{- '<>\n' + content.strip() + '\n<>\n\n' -}} - {% elif message['role'] == 'assistant' %} - {{- ' ' + content.strip() + ' ' + eos_token -}} - {% endif %} -{% endfor %} +{%- if messages[0]['role'] == 'system' %} + {%- set loop_messages = messages[1:] %} + {%- set system_message = messages[0]['content'] %} +{%- else %} + {%- set loop_messages = messages %} + {%- set system_message = false %} +{%- endif %} +{%- for message in loop_messages %} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif %} + {%- if loop.index0 == 0 and system_message != false %} + {%- set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %} + {%- else %} + {%- set content = message['content'] %} + {%- endif %} + {%- if message['role'] == 'user' %} + {{- bos_token + '[INST] ' + content.strip() + ' [/INST]' }} + {%- elif message['role'] == 'assistant' %} + {%- generation %} + {{- ' ' + content.strip() + ' ' + eos_token }} + {%- endgeneration %} + {%- endif %} +{%- endfor %} diff --git a/src/llm_training/data/chat_templates/llama-3.j2 b/src/llm_training/data/chat_templates/llama-3.j2 index 14af2c5..75cf529 100644 --- a/src/llm_training/data/chat_templates/llama-3.j2 +++ b/src/llm_training/data/chat_templates/llama-3.j2 @@ -1,15 +1,19 @@ -{% set is_splitted = index is defined and length is defined %} -{% for message in messages %} - {% set header = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' %} - {% set content = message['content'] | trim + '<|eot_id|>' %} - {% if message['role'] != 'assistant' or not is_splitted %} - {% set content = header + content %} - {% endif %} - {% if (is_splitted and index == 0 and loop.index0 == 0) or (not is_splitted and loop.index0 == 0) %} - {% set content = bos_token + content %} - {% endif %} - {{- content -}} -{% endfor %} -{% if add_generation_prompt %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} -{% endif %} +{%- set loop_messages = messages %} +{%- for message in loop_messages %} + {%- set header = '<|start_header_id|>' + message.role + '<|end_header_id|>\n\n' %} + {%- set content = message.content | trim + '<|eot_id|>' %} + {%- if loop.index0 == 0 %} + {%- set header = bos_token + header %} + {%- endif %} + {{- header -}} + {%- if message.role == 'assistant' %} + {% generation %} + {{- content -}} + {% 
endgeneration %} + {%- else %} + {{- content }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/src/llm_training/data/chat_templates/phi-3.j2 b/src/llm_training/data/chat_templates/phi-3.j2 index e5a40ac..7c1452d 100644 --- a/src/llm_training/data/chat_templates/phi-3.j2 +++ b/src/llm_training/data/chat_templates/phi-3.j2 @@ -1,14 +1,20 @@ -{% set is_splitted = index is defined and length is defined %} -{% for message in messages %} - {% set content = message['content'] + '<|end|>\n' %} - {% if message['role'] != 'assistant' or not is_splitted %} - {% set content = '<|' + message['role'] + '|>\n' + content %} - {% endif %} - {% if (is_splitted and index == 0 and loop.index0 == 0) or (not is_splitted and loop.index0 == 0) %} - {% set content = bos_token + content %} - {% endif %} - {{- content -}} -{% endfor %} -{% if add_generation_prompt %} - {{- '<|assistant|>\n' -}} -{% endif %} +{%- for message in messages %} + {%- set content = message.content | trim %} + {%- if message.role == 'system' and content %} + {{- '<|system|>\n' + content + '<|end|>\n' }} + {%- elif message.role == 'user' %} + {{- '<|user|>\n' + content + '<|end|>\n' }} + {%- elif message.role == 'assistant' %} + {{- '<|assistant|>\n' -}} + {% generation %} + {{- content + '<|end|>\n' -}} + {% endgeneration %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|assistant|>\n' }} +{%- else %} + {% generation %} + {{- eos_token -}} + {% endgeneration %} +{%- endif %} diff --git a/src/llm_training/data/chat_templates/tulu-2.j2 b/src/llm_training/data/chat_templates/tulu-2.j2 index a419f24..c6eccdb 100644 --- a/src/llm_training/data/chat_templates/tulu-2.j2 +++ b/src/llm_training/data/chat_templates/tulu-2.j2 @@ -1,27 +1,14 @@ -{% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} -{% else %} - {% set loop_messages = messages %} - {% set system_message = false %} -{% endif %} -{% if system_message != false and (loop_messages|length == 0 or loop_messages|length == 1 and loop_messages[0]['role'] != 'user') %} - {{ raise_exception('The system prompt must be passed along with a user prompt.') }} -{% endif %} -{% for message in loop_messages %} - {% if loop_messages|length > 1 and (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{- bos_token + '<|user|>\n' + content.strip() + '\n<|assistant|>\n' -}} - {% elif message['role'] == 'system' %} - {{- '<>\n' + content.strip() + '\n<>\n\n' -}} - {% elif message['role'] == 'assistant' %} - {{- content.strip() + eos_token -}} - {% endif %} -{% endfor %} +{%- for message in messages %} + {%- if message.role == 'system' %} + {{- '<|system|>\n' + message.content }} + {%- elif message.role == 'user' %} + {{- '<|user|>\n' + message.content }} + {%- elif message.role == 'assistant' %} + {% generation %} + {{- '<|assistant|>\n' + message.content + eos_token -}} + {% endgeneration %} + {%- endif %} + {%- if loop.last and add_generation_prompt %} + {{- '<|assistant|>' }} + {%- endif %} +{%- endfor %} diff --git 
a/src/llm_training/data/instruction_tuning/instruction_tuning_datamodule.py b/src/llm_training/data/instruction_tuning/instruction_tuning_datamodule.py index f231741..6b6d7a1 100644 --- a/src/llm_training/data/instruction_tuning/instruction_tuning_datamodule.py +++ b/src/llm_training/data/instruction_tuning/instruction_tuning_datamodule.py @@ -21,16 +21,16 @@ def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: dataset_dict = self.map_dataset_dict( dataset_dict, _apply_chat_template_and_tokenize, - input_columns='messages', - remove_columns=True, fn_kwargs=dict( tokenizer=self.config.tokenizer, chat_template=self.config.chat_template, add_default_system_prompt_rate=self.config.add_default_system_prompt_rate, default_system_prompt=self.config.default_system_prompt ), + batched=True, + remove_columns=True, num_proc=self.config.num_proc, - desc='Apply template and tokenize' + desc='Apply chat template and tokenize' ) if self.config.overlong_handling_method == OverlongHandlingMethod.DROP: @@ -67,60 +67,48 @@ def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: def _apply_chat_template_and_tokenize( - messages: list[dict[str, str]], + batch: dict[str, list[str]], tokenizer: PreTrainedTokenizerBase, chat_template: str | None, default_system_prompt: str | None, add_default_system_prompt_rate: float | None ): - input_ids = [] - labels = [] - - # Add an empty system prompt randomly if it does not exist. - has_system_prompt = any(m['role'] == 'system' for m in messages) - if ( - not has_system_prompt - and default_system_prompt is not None - and add_default_system_prompt_rate is not None - and random.random() < add_default_system_prompt_rate - ): - messages.insert(0, {'role': 'system', 'content': default_system_prompt}) - - system_prompt = None - if messages[0]['role'] == 'system': - system_prompt = messages.pop(0) - - for i, message in enumerate(messages): - conversation = [message] - if i == 0 and system_prompt is not None: - conversation.insert(0, system_prompt) - text = tokenizer.apply_chat_template( - conversation, - chat_template=chat_template, - tokenize=False, - add_generation_prompt=message['role'] == 'user', - index=i, - length=len(messages) + new_batch = { + 'input_ids': [], + 'labels': [] + } + + for messages in batch['messages']: + # Add an empty system prompt randomly if it does not exist. 
+ has_system_prompt = any(m['role'] == 'system' for m in messages) + if ( + not has_system_prompt + and default_system_prompt is not None + and add_default_system_prompt_rate is not None + and random.random() < add_default_system_prompt_rate + ): + messages.insert(0, {'role': 'system', 'content': default_system_prompt}) + + batch_encoding = tokenizer.apply_chat_template( + batch['messages'], + chat_template=chat_template, + return_dict=True, + return_assistant_tokens_mask=True, + tokenizer_kwargs=dict( + return_attention_mask=False, + verbose=False ) - # 這裡將同一筆資料分多次 tokenize,為保證跟一次 tokenize 全部的結果相同 - # 先在前面加一個 token,encode 後再移除掉 - text = tokenizer.bos_token + text - current_input_ids = tokenizer.encode(text, add_special_tokens=False) - current_input_ids = current_input_ids[1:] - - if message['role'] in ['system', 'user']: - input_ids += current_input_ids - labels += [-100] * len(current_input_ids) - elif message['role'] == 'assistant': - input_ids += current_input_ids - labels += current_input_ids - else: - raise ValueError(f"Unknown role: `{message['role']}`") + ) + + for input_ids, assistant_masks in zip( + batch_encoding['input_ids'], + batch_encoding['assistant_masks'] + ): + labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)] + new_batch['input_ids'].append(input_ids) + new_batch['labels'].append(labels) - return { - 'input_ids': input_ids, - 'labels': labels - } + return new_batch def _drop_overlong(input_ids: list[int], max_length: int): diff --git a/src/llm_training/data/preference_tuning/preference_tuning_datamodule.py b/src/llm_training/data/preference_tuning/preference_tuning_datamodule.py index 270943c..a5f53b6 100644 --- a/src/llm_training/data/preference_tuning/preference_tuning_datamodule.py +++ b/src/llm_training/data/preference_tuning/preference_tuning_datamodule.py @@ -1,5 +1,3 @@ -from typing import Any - from transformers import PreTrainedTokenizerBase from llm_training.data.hf_based.hf_based_datamodule import (DatasetDict, @@ -21,11 +19,12 @@ def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: dataset_dict = self.map_dataset_dict( dataset_dict, _apply_chat_template_and_tokenize, - remove_columns=True, fn_kwargs=dict( tokenizer=self.config.tokenizer, chat_template=self.config.chat_template ), + batched=True, + remove_columns=True, num_proc=self.config.num_proc, desc='Apply chat template and tokenize' ) @@ -51,79 +50,70 @@ def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: return dataset_dict -def _apply_chat_template_and_tokenize_single( - messages: list[dict[str, str]], - tokenizer: PreTrainedTokenizerBase, - chat_template: str | None -) -> tuple[list[int], list[int]]: - input_ids = [] - labels = [] - - system_prompt = None - if messages[0]['role'] == 'system': - system_prompt = messages.pop(0) - - for i, message in enumerate(messages): - conversation = [message] - if i == 0 and system_prompt is not None: - conversation.insert(0, system_prompt) - text = tokenizer.apply_chat_template( - conversation, - chat_template=chat_template, - tokenize=False, - add_generation_prompt=message['role'] == 'user', - index=i, - length=len(messages) - ) - # 這裡將同一筆資料分多次 tokenize,為保證跟一次 tokenize 全部的結果相同 - # 先在前面加一個 token,encode 後再移除掉 - text = tokenizer.bos_token + text - current_input_ids = tokenizer.encode(text, add_special_tokens=False) - current_input_ids = current_input_ids[1:] - - if message['role'] in ['system', 'user']: - input_ids += current_input_ids - labels += [-100] * len(current_input_ids) - elif message['role'] == 
'assistant': - input_ids += current_input_ids - labels += current_input_ids - else: - raise ValueError(f"Unknown role: `{message['role']}`") - - return input_ids, labels - - def _apply_chat_template_and_tokenize( - example: dict[str, Any], + batch: dict[str, list[str]], tokenizer: PreTrainedTokenizerBase, chat_template: str | None ): - chosen_input_ids, chosen_labels = _apply_chat_template_and_tokenize_single( - [ - {'role': 'user', 'content': example['prompt']}, - {'role': 'assistant', 'content': example['chosen']} - ], - tokenizer=tokenizer, - chat_template=chat_template - ) + new_batch = { + 'chosen_input_ids': [], + 'chosen_labels': [], + 'chosen_length': [], + 'rejected_input_ids': [], + 'rejected_labels': [], + 'rejected_length': [] + } - rejected_input_ids, rejected_labels = _apply_chat_template_and_tokenize_single( - [ - {'role': 'user', 'content': example['prompt']}, - {'role': 'assistant', 'content': example['rejected']} - ], - tokenizer=tokenizer, - chat_template=chat_template + chosen_messages = [] + rejected_messages = [] + for prompt, chosen, rejected in zip( + batch['prompt'], + batch['chosen'], + batch['rejected'] + ): + chosen_messages.append([ + {'role': 'user', 'content': prompt}, + {'role': 'assistant', 'content': chosen} + ]) + + rejected_messages.append([ + {'role': 'user', 'content': prompt}, + {'role': 'assistant', 'content': rejected} + ]) + + kwargs = dict( + chat_template=chat_template, + return_dict=True, + return_assistant_tokens_mask=True, + tokenizer_kwargs=dict( + return_attention_mask=False, + verbose=False + ) ) - return { - 'chosen_input_ids': chosen_input_ids, - 'chosen_labels': chosen_labels, - 'chosen_length': len(chosen_input_ids), - 'rejected_input_ids': rejected_input_ids, - 'rejected_labels': rejected_labels, - 'rejected_length': len(rejected_input_ids), - } + chosen_batch_encoding = tokenizer.apply_chat_template(chosen_messages, **kwargs) + for input_ids, assistant_masks in zip( + chosen_batch_encoding['input_ids'], + chosen_batch_encoding['assistant_masks'] + ): + labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)] + i = input_ids.index(32001) + assert assistant_masks[i] == 0 + new_batch['chosen_input_ids'].append(input_ids) + new_batch['chosen_labels'].append(labels) + new_batch['chosen_length'].append(len(input_ids)) + + rejected_batch_encoding = tokenizer.apply_chat_template(rejected_messages, **kwargs) + for input_ids, assistant_masks in zip( + rejected_batch_encoding['input_ids'], + rejected_batch_encoding['assistant_masks'] + ): + labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)] + new_batch['rejected_input_ids'].append(input_ids) + new_batch['rejected_labels'].append(labels) + new_batch['rejected_length'].append(len(input_ids)) + + return new_batch def _drop_overlong(chosen_length: int, rejected_length: int, max_length: int):
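The label masking introduced in PATCH 5/5 relies on the {% generation %} / {% endgeneration %} blocks that now wrap every assistant turn in the chat templates: with return_dict=True and return_assistant_tokens_mask=True, tokenizer.apply_chat_template reports which tokens belong to assistant responses, and the data modules keep those token ids as labels while masking everything else with -100. A minimal sketch of that flow, assuming a transformers version with assistant-token-mask support and a fast tokenizer; the checkpoint and template path are taken from this patch series, the rest is illustrative rather than the framework's actual call site:

from transformers import AutoTokenizer

# Tokenizer taken from the example config in this series; any fast
# tokenizer paired with a chat template containing {% generation %} blocks works.
tokenizer = AutoTokenizer.from_pretrained('microsoft/Phi-3-mini-128k-instruct')

# One of the templates updated in PATCH 5/5.
with open('src/llm_training/data/chat_templates/phi-3.j2') as f:
    chat_template = f.read()

conversations = [[
    {'role': 'user', 'content': 'Hello!'},
    {'role': 'assistant', 'content': 'Hi, how can I help you?'},
]]

encoding = tokenizer.apply_chat_template(
    conversations,
    chat_template=chat_template,
    return_dict=True,
    return_assistant_tokens_mask=True,
    tokenizer_kwargs=dict(return_attention_mask=False, verbose=False),
)

input_ids = encoding['input_ids'][0]
assistant_masks = encoding['assistant_masks'][0]

# Assistant tokens keep their ids as labels; user/system tokens become -100.
labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)]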
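PATCH 1/5 extends print_dataset_info in PreTrainingDataModule with per-source token accounting: each split's 'source' and 'length' columns are folded into a Counter (with '*' as the catch-all row) and rendered by tabulate in orgtbl format. A self-contained sketch of that aggregation, with made-up rows standing in for a real dataset split:

from collections import Counter
from tabulate import tabulate

# Stand-in rows for a dataset split exposing 'source' and 'length' columns.
rows = [
    {'source': 'wikipedia', 'length': 120},
    {'source': 'wikipedia', 'length': 80},
    {'source': 'common_crawl', 'length': 300},
]

counter = Counter()
for row in rows:
    counter[row['source']] += row['length']
    counter['*'] += row['length']  # '*' accumulates tokens across all sources

print(tabulate(
    [['train', source, tokens] for source, tokens in counter.most_common()],
    headers=['Split', 'Source', 'Tokens'],
    tablefmt='orgtbl',
))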
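The OutputStreamRedirector helper added in scripts/pre_process_data.py (PATCH 1/5) is a small tee over text streams, letting print_dataset_info(file=...) write the dataset summary to stdout and to info.txt in one pass. A condensed, runnable stand-in (the class name Tee and the literal output string are illustrative; the real helper lives in the script):

import sys
from typing import TextIO


class Tee:
    """Forward every write and flush to all wrapped text streams."""

    def __init__(self, *streams: TextIO) -> None:
        self._streams = streams

    def write(self, s: str) -> int:
        return sum(stream.write(s) for stream in self._streams)

    def flush(self) -> None:
        for stream in self._streams:
            stream.flush()


with open('info.txt', 'w') as f:
    out = Tee(sys.stdout, f)
    print('Raw Dataset', file=out)  # written to stdout and info.txt at once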