DataModule small refactor #12

Closed · wants to merge 5 commits
1 change: 1 addition & 0 deletions config/examples/phi-3/phi-3-mini_dpo_example.yaml
@@ -60,6 +60,7 @@ data:
      class_path: HFTokenizer
      init_args:
        path: microsoft/Phi-3-mini-128k-instruct
        chat_template: phi-3
    batch_size: 1
    max_length: 4096
    pad_to_multiple_of: 64
4 changes: 2 additions & 2 deletions docs/pre_training.md
@@ -21,7 +21,7 @@ For a complete set of parameters, please refer to [`PreTrainingDataModuleConfig`

Before training begins, the framework automatically processes the data, ensuring everything is ready when training starts. This is particularly convenient when dealing with small training datasets. However, pre-training datasets are typically large, and the CPUs available during training are often limited, making this step very time-consuming.

To address this issue, you can set `pre_processed_data_path` and use many CPUs to execute `scripts/pre_process_pre_training_data.py` for pre-processing and saving the data in advance.
To address this issue, you can set `pre_processed_data_path` and use many CPUs to execute `scripts/pre_process_data.py` for pre-processing and saving the data in advance.

Remember to set `num_proc` to the desired number of CPUs to utilize.

@@ -35,7 +35,7 @@ data:
```

```bash
python scripts/pre_process_pre_training_data.py -c <CONFIG_PATH>
python scripts/pre_process_data.py -c <CONFIG_PATH>
```
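For reference, a minimal config sketch combining the two options above (the `class_path` value is an assumption for illustration; only `pre_processed_data_path` and `num_proc` come from this page):

```yaml
data:
  class_path: PreTrainingDataModule  # assumed class name, for illustration
  init_args:
    pre_processed_data_path: /path/to/pre_processed_data
    num_proc: 32  # number of CPU processes used for pre-processing
```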

## Data Sampling
55 changes: 55 additions & 0 deletions scripts/pre_process_data.py
@@ -0,0 +1,55 @@
import os
import sys
from typing import TextIO

from llm_training.data import *
from llm_training.models import *
from llm_training.overrides.cli import *


class OutputStreamRedirector:
    """Tee-like writer that forwards every write to multiple text streams."""

    def __init__(self, *streams: TextIO) -> None:
        self._streams = streams

    def write(self, s: str) -> int:
        n = 0
        for stream in self._streams:
            n += stream.write(s)
        return n

    def flush(self) -> None:
        for s in self._streams:
            s.flush()


def main():
    cli = LightningCLI(run=False)

    datamodule = cli.datamodule

    assert isinstance(datamodule, HFBasedDataModule)

    config = datamodule.config

    # Remember the user's settings; they are temporarily overridden below.
    # `pre_processed_data_path` must be set in the config for this script.
    enable_cache = config.enable_cache
    pre_processed_data_path = config.pre_processed_data_path

    if not os.path.exists(pre_processed_data_path) or len(os.listdir(pre_processed_data_path)) == 0:
        # Pre-process from scratch with caching enabled, then save the result.
        config.pre_processed_data_path = None
        config.enable_cache = True
        datamodule.setup()
        datamodule.save_pre_processed_data(pre_processed_data_path)
    else:
        print(f'`pre_processed_data_path="{pre_processed_data_path}"` is not empty, skipping.')
        datamodule.setup()

    if not enable_cache:
        n = datamodule.cleanup_cache_files()
        print(f'Cleanup cache files: {n}')

    # Write dataset statistics to stdout and to info.txt alongside the data.
    with open(os.path.join(pre_processed_data_path, 'info.txt'), 'w') as f:
        datamodule.print_dataset_info(file=OutputStreamRedirector(sys.stdout, f))


if __name__ == '__main__':
    main()
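As an aside, `OutputStreamRedirector` is just a minimal tee over text streams; a usage sketch (the file name is arbitrary):

```python
import sys

# Everything printed to `out` goes to the terminal and the file at once.
with open('log.txt', 'w') as f:
    out = OutputStreamRedirector(sys.stdout, f)
    print('dataset info...', file=out)
```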
76 changes: 0 additions & 76 deletions scripts/pre_process_pre_training_data.py

This file was deleted.

1 change: 1 addition & 0 deletions src/llm_training/data/__init__.py
@@ -2,6 +2,7 @@
from .base_datamodule import BaseDataModule
from .base_datamodule_config import BaseDataModuleConfig
from .dummy import *
from .hf_based import *
from .instruction_tuning import *
from .pre_training import *
from .preference_tuning import *
23 changes: 19 additions & 4 deletions src/llm_training/data/base_datamodule.py
@@ -1,6 +1,7 @@
import logging
import os
from functools import partial
from typing import Mapping
from typing import Mapping, TextIO

import lightning as L
from torch.utils.data import DataLoader, Dataset
@@ -26,9 +27,7 @@ def __init__(self, config: BaseDataModuleConfig) -> None:
        self.raw_dataset_dict = None
        self.pre_processed_dataset_dict = None
        self.dataset_dict = None

        self.train_dataloader_state = {}


    def load_data(self) -> DatasetDict:
        raise NotImplementedError()

@@ -53,6 +52,22 @@ def save_pre_processed_data(self, path: str | None = None) -> None:
    def load_pre_processed_data(self, path: str | None = None) -> None:
        raise NotImplementedError()

    def print_dataset_info(self, file: TextIO | None = None) -> None:
        print_ = partial(print, file=file)

        def print_header(header: str) -> None:
            # Center the header between '=' dividers across the terminal width.
            n = os.get_terminal_size().columns
            m = (n - len(header) - 2) // 2
            divider = '=' * m
            header = f'{divider} {header} {divider}'
            print_(f'{header:^{n}}', end='\n\n')

        print_header('Raw Dataset')
        print_(self.raw_dataset_dict, end='\n\n')
        print_header('Pre-processed Dataset')
        print_(self.pre_processed_dataset_dict, end='\n\n')
        print_header('Final Dataset')
        print_(self.dataset_dict, end='\n\n')

    def _get_dataloader(self, split: str):
        dataloader_class = DataLoader
        dataloader_kwargs = dict(
3 changes: 2 additions & 1 deletion src/llm_training/data/chat_templates/__init__.py
@@ -5,7 +5,8 @@

class _ChatTemplates:
    def _get_path_by_name(self, name: str) -> Path:
        return Path(__file__).parent.joinpath(name).with_suffix('.j2')
        p = Path(__file__).parent / name
        return p.with_suffix(f'{p.suffix}.j2')

    def __getitem__(self, name: str) -> str:
        if name not in self:
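For context on this change: `with_suffix('.j2')` replaces any existing suffix, so a template name containing a dot would resolve to the wrong file, while the new code appends `.j2` after whatever suffix the name already carries. A quick sketch (the name `llama-3.1` is hypothetical):

```python
from pathlib import Path

name = 'llama-3.1'  # hypothetical template name containing a dot

# Old behaviour: '.1' is treated as the suffix and gets replaced.
Path(name).with_suffix('.j2')    # Path('llama-3.j2') -- wrong file

# New behaviour: '.j2' is appended after the existing suffix.
p = Path(name)
p.with_suffix(f'{p.suffix}.j2')  # Path('llama-3.1.j2')
```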
25 changes: 14 additions & 11 deletions src/llm_training/data/chat_templates/chatml.j2
@@ -1,11 +1,14 @@
{% set is_splitted = index is defined and length is defined %}
{% for message in messages %}
{% set content = message['content'] + '<|im_end|>\n' %}
{% if message['role'] != 'assistant' or not is_splitted %}
{% set content = '<|im_start|>' + message['role'] + '\n' + content %}
{% endif %}
{{- content -}}
{% endfor %}
{% if add_generation_prompt %}
{{- '<|im_start|>assistant\n' -}}
{% endif %}
{%- for message in messages %}
{{- '<|im_start|>' + message.role + '\n' }}
{%- set content = message.content + '<|im_end|>\n' %}
{%- if message.role == 'assistant' %}
{% generation %}
{{- content -}}
{% endgeneration %}
{%- else %}
{{- content }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
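The `{% generation %}` / `{% endgeneration %}` tags follow the convention Hugging Face tokenizers use for assistant-token masking. A minimal sketch of how such a template can be consumed (the checkpoint name is a placeholder, and this assumes a recent `transformers` release):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('your/model')  # placeholder checkpoint
tokenizer.chat_template = open('chatml.j2').read()       # the template above

messages = [
    {'role': 'user', 'content': 'Hi!'},
    {'role': 'assistant', 'content': 'Hello, how can I help?'},
]

out = tokenizer.apply_chat_template(
    messages,
    return_dict=True,
    return_assistant_tokens_mask=True,  # needs {% generation %} in the template
)
# out['assistant_masks'][i] == 1 exactly for tokens inside a generation
# block, i.e. the usual loss mask for instruction tuning.
```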
36 changes: 15 additions & 21 deletions src/llm_training/data/chat_templates/gemma.j2
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
{% set is_splitted = index is defined and length is defined %}
{% for message in messages %}
{% set index = index if is_splitted else loop.index0 %}
{% set length = length if is_splitted else messages|length %}
{% set role = message['role'] if message['role'] != 'assistant' else 'model' %}
{% set header = '<start_of_turn>' + role + '\n' %}
{% set content = message['content'] | trim + '<end_of_turn>\n' %}
{% if message['role'] != 'assistant' or not is_splitted %}
{% set content = header + content %}
{% endif %}
{% if index == 0 and loop.index0 == 0 %}
{% set content = bos_token + content %}
{% endif %}
{% if index == length - 1 and loop.index0 == messages|length - 1 and message['role'] == 'assistant' %}
{% set content = content + eos_token %}
{% endif %}
{{- content -}}
{% endfor %}
{% if add_generation_prompt %}
{{- '<start_of_turn>model\n' -}}
{% endif %}
{{- bos_token }}
{%- for message in messages %}
{%- set content = message.content | trim + '<end_of_turn>\n' %}
{%- set role = 'model' if message.role == 'assistant' else message.role %}
{{- '<start_of_turn>' + role + '\n' }}
{%- if message.role == 'assistant' %}
{% generation %}
{{- content }}
{% endgeneration %}
{%- else %}
{{- content }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}{{'<start_of_turn>model\n'}}
{%- endif %}
51 changes: 24 additions & 27 deletions src/llm_training/data/chat_templates/llama-2.j2
@@ -1,27 +1,24 @@
{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% set system_message = false %}
{% endif %}
{% if system_message != false and (loop_messages|length == 0 or loop_messages|length == 1 and loop_messages[0]['role'] != 'user') %}
{{ raise_exception('The system prompt must be passed along with a user prompt.') }}
{% endif %}
{% for message in loop_messages %}
{% if loop_messages|length > 1 and (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 and system_message != false %}
{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{- bos_token + '[INST] ' + content.strip() + ' [/INST]' -}}
{% elif message['role'] == 'system' %}
{{- '<<SYS>>\n' + content.strip() + '\n<</SYS>>\n\n' -}}
{% elif message['role'] == 'assistant' %}
{{- ' ' + content.strip() + ' ' + eos_token -}}
{% endif %}
{% endfor %}
{%- if messages[0]['role'] == 'system' %}
{%- set loop_messages = messages[1:] %}
{%- set system_message = messages[0]['content'] %}
{%- else %}
{%- set loop_messages = messages %}
{%- set system_message = false %}
{%- endif %}
{%- for message in loop_messages %}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif %}
{%- if loop.index0 == 0 and system_message != false %}
{%- set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}
{%- else %}
{%- set content = message['content'] %}
{%- endif %}
{%- if message['role'] == 'user' %}
{{- bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
{%- elif message['role'] == 'assistant' %}
{%- generation %}
{{- ' ' + content.strip() + ' ' + eos_token }}
{%- endgeneration %}
{%- endif %}
{%- endfor %}
34 changes: 19 additions & 15 deletions src/llm_training/data/chat_templates/llama-3.j2
@@ -1,15 +1,19 @@
{% set is_splitted = index is defined and length is defined %}
{% for message in messages %}
{% set header = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' %}
{% set content = message['content'] | trim + '<|eot_id|>' %}
{% if message['role'] != 'assistant' or not is_splitted %}
{% set content = header + content %}
{% endif %}
{% if (is_splitted and index == 0 and loop.index0 == 0) or (not is_splitted and loop.index0 == 0) %}
{% set content = bos_token + content %}
{% endif %}
{{- content -}}
{% endfor %}
{% if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{% endif %}
{%- set loop_messages = messages %}
{%- for message in loop_messages %}
{%- set header = '<|start_header_id|>' + message.role + '<|end_header_id|>\n\n' %}
{%- set content = message.content | trim + '<|eot_id|>' %}
{%- if loop.index0 == 0 %}
{%- set header = bos_token + header %}
{%- endif %}
{{- header -}}
{%- if message.role == 'assistant' %}
{% generation %}
{{- content -}}
{% endgeneration %}
{%- else %}
{{- content }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}