diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index 9c7a2966..83117b8b 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -43,6 +43,7 @@ def __init__(self, prompt_dict_keys=None, response_key='response', response_dict_keys=None, + filter_data=True, max_length=1024, truncation='error'): assert truncation in ['error', 'left', 'right'] @@ -56,11 +57,12 @@ def __init__(self, tokenizer = hf_tokenizer(tokenizer) self.tokenizer: PreTrainedTokenizer = tokenizer - self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key] - self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key] + self.prompt_key = prompt_key[0] if isinstance(prompt_key, (tuple, list)) else prompt_key + self.response_key = response_key[0] if isinstance(response_key, (tuple, list)) else response_key self.prompt_dict_keys = [] if not prompt_dict_keys else prompt_dict_keys self.response_dict_keys = [] if not response_dict_keys else response_dict_keys + self.filter_data = filter_data self.max_length = max_length self._download() @@ -84,6 +86,7 @@ def series_to_item(ls): dataframe = pd.read_parquet(parquet_file) dataframes.append(dataframe) self.dataframe = pd.concat(dataframes) + self.prompts = self.dataframe[self.prompt_key] for key in self.prompt_dict_keys: # type(x): pandas.core.series.Series @@ -94,7 +97,7 @@ def series_to_item(ls): except Exception: print(f'self.prompts={self.prompts}') raise - self.prompts = self.prompts.tolist() + self.responses = self.dataframe[self.response_key] for key in self.response_dict_keys: try: @@ -102,7 +105,17 @@ def series_to_item(ls): except Exception: print(f'self.responses={self.responses}') raise - self.responses = self.responses.tolist() + + if self.filter_data: + self.dataframe = pd.DataFrame({self.prompts.name: self.prompts, self.responses.name: self.responses}) + self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len( + self.tokenizer.apply_chat_template([{'role': 'user', 'content': doc[self.prompts.name]},{'role': 'assistant', 'content': doc[self.responses.name]}], add_generation_prompt=True)) <= self.max_length, + axis=1)] + self.prompts = self.dataframe[self.prompts.name].tolist() + self.responses = self.dataframe[self.responses.name].tolist() + else: + self.prompts = self.prompts.tolist() + self.responses = self.responses.tolist() def __len__(self): return len(self.prompts)