Commit

feat: adding num_proc to load_dataset (#26326)
* feat: adding num_proc to load_dataset

* feat: add add_num_proc for run_mlm_flax

* feat: add num_proc for bart and t5

* chore: remove
pphuc25 authored Sep 22, 2023
1 parent 576cd45 commit 910faa3
Showing 4 changed files with 24 additions and 0 deletions.
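
The commit threads the scripts' existing preprocessing_num_workers argument into every load_dataset call, so downloading and preparing the raw dataset can run across multiple processes instead of one. A minimal sketch of the resulting call pattern (the dataset name, config, cache path, and worker count below are illustrative placeholders, not values from this commit):

    from datasets import load_dataset

    # Illustrative values; the scripts fill these in at runtime from
    # DataTrainingArguments (data_args) and ModelArguments (model_args).
    datasets = load_dataset(
        "oscar",                       # data_args.dataset_name
        "unshuffled_deduplicated_en",  # data_args.dataset_config_name
        cache_dir="/tmp/hf_cache",     # model_args.cache_dir
        num_proc=8,                    # data_args.preprocessing_num_workers
    )

num_proc is accepted by datasets.load_dataset in recent versions of the datasets library and controls how many processes are used to download and generate the dataset locally.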
6 changes: 6 additions & 0 deletions examples/flax/language-modeling/run_bart_dlm_flax.py
@@ -531,6 +531,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in datasets.keys():
@@ -540,13 +541,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -562,6 +565,7 @@ def main():
             data_files=data_files,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in datasets.keys():
@@ -571,13 +575,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 extension,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
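The hunks above also add the keyword to the fallback loads that carve a validation split out of train; num_proc composes with split slicing. A small sketch under the same illustrative values as before:

    from datasets import load_dataset

    # When a dataset ships without a validation split, slice one out of
    # `train` while still preparing the data with multiple processes.
    val_ds = load_dataset("oscar", "unshuffled_deduplicated_en",
                          split="train[:5%]", num_proc=8)
    train_ds = load_dataset("oscar", "unshuffled_deduplicated_en",
                            split="train[5%:]", num_proc=8)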
6 changes: 6 additions & 0 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -421,6 +421,7 @@ def main():
             cache_dir=model_args.cache_dir,
             keep_in_memory=False,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in dataset.keys():
@@ -430,13 +431,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             dataset["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -455,6 +458,7 @@ def main():
             cache_dir=model_args.cache_dir,
             **dataset_args,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in dataset.keys():
@@ -465,6 +469,7 @@ def main():
                 cache_dir=model_args.cache_dir,
                 **dataset_args,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             dataset["train"] = load_dataset(
                 extension,
@@ -473,6 +478,7 @@ def main():
                 cache_dir=model_args.cache_dir,
                 **dataset_args,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
6 changes: 6 additions & 0 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -458,6 +458,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in datasets.keys():
@@ -467,13 +468,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -489,6 +492,7 @@ def main():
             data_files=data_files,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in datasets.keys():
@@ -498,13 +502,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 extension,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
6 changes: 6 additions & 0 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -572,6 +572,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in datasets.keys():
@@ -581,13 +582,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -603,6 +606,7 @@ def main():
             data_files=data_files,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )

         if "validation" not in datasets.keys():
@@ -612,13 +616,15 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 extension,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
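All four scripts already expose the field that feeds this keyword through their DataTrainingArguments dataclass, so no new CLI flag is needed. Trimmed to the relevant field, the definition looks roughly like this (a sketch, not the full dataclass):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class DataTrainingArguments:
        # Previously only forwarded to Dataset.map() during tokenization and
        # grouping; after this commit it also fans out load_dataset() itself.
        preprocessing_num_workers: Optional[int] = field(
            default=None,
            metadata={"help": "The number of processes to use for the preprocessing."},
        )

Via HfArgumentParser this surfaces on the command line as --preprocessing_num_workers.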
