Skip to content

Commit

Permalink
sftdataset filter
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyuenlp committed Feb 19, 2025
1 parent 35c8dae commit a11dd68
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions verl/utils/dataset/sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self,
prompt_dict_keys=None,
response_key='response',
response_dict_keys=None,
filter_data=True,
max_length=1024,
truncation='error'):
assert truncation in ['error', 'left', 'right']
Expand All @@ -56,11 +57,12 @@ def __init__(self,
tokenizer = hf_tokenizer(tokenizer)
self.tokenizer: PreTrainedTokenizer = tokenizer

self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key]
self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key]
self.prompt_key = prompt_key[0] if isinstance(prompt_key, (tuple, list)) else prompt_key
self.response_key = response_key[0] if isinstance(response_key, (tuple, list)) else response_key
self.prompt_dict_keys = [] if not prompt_dict_keys else prompt_dict_keys
self.response_dict_keys = [] if not response_dict_keys else response_dict_keys

self.filter_data = filter_data
self.max_length = max_length

self._download()
Expand All @@ -84,6 +86,7 @@ def series_to_item(ls):
dataframe = pd.read_parquet(parquet_file)
dataframes.append(dataframe)
self.dataframe = pd.concat(dataframes)

self.prompts = self.dataframe[self.prompt_key]
for key in self.prompt_dict_keys:
# type(x): pandas.core.series.Series
Expand All @@ -94,15 +97,25 @@ def series_to_item(ls):
except Exception:
print(f'self.prompts={self.prompts}')
raise
self.prompts = self.prompts.tolist()

self.responses = self.dataframe[self.response_key]
for key in self.response_dict_keys:
try:
self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1)
except Exception:
print(f'self.responses={self.responses}')
raise
self.responses = self.responses.tolist()

if self.filter_data:
self.dataframe = pd.DataFrame({self.prompts.name: self.prompts, self.responses.name: self.responses})
self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
self.tokenizer.apply_chat_template([{'role': 'user', 'content': doc[self.prompts.name]},{'role': 'assistant', 'content': doc[self.responses.name]}], add_generation_prompt=True)) <= self.max_length,
axis=1)]
self.prompts = self.dataframe[self.prompts.name].tolist()
self.responses = self.dataframe[self.responses.name].tolist()
else:
self.prompts = self.prompts.tolist()
self.responses = self.responses.tolist()

def __len__(self):
    """Return the number of prompt/response examples in the dataset.

    Reflects the post-filtering count: ``self.prompts`` is the list built
    during loading (and pruned when ``filter_data`` is enabled), so ``len``
    always matches what ``__getitem__`` can serve.
    """
    num_examples = len(self.prompts)
    return num_examples
Expand Down

0 comments on commit a11dd68

Please sign in to comment.