diff --git a/examples/flax/language-modeling/run_bart_dlm_flax.py b/examples/flax/language-modeling/run_bart_dlm_flax.py
index cce2253e9ae2..d1928dd73130 100644
--- a/examples/flax/language-modeling/run_bart_dlm_flax.py
+++ b/examples/flax/language-modeling/run_bart_dlm_flax.py
@@ -531,6 +531,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in datasets.keys():
@@ -540,6 +541,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
@@ -547,6 +549,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -562,6 +565,7 @@ def main():
             data_files=data_files,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in datasets.keys():
@@ -571,6 +575,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 extension,
@@ -578,6 +583,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py
index b740dfcffe12..95e175d494bf 100755
--- a/examples/flax/language-modeling/run_clm_flax.py
+++ b/examples/flax/language-modeling/run_clm_flax.py
@@ -421,6 +421,7 @@ def main():
             cache_dir=model_args.cache_dir,
             keep_in_memory=False,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in dataset.keys():
@@ -430,6 +431,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             dataset["train"] = load_dataset(
                 data_args.dataset_name,
@@ -437,6 +439,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -455,6 +458,7 @@ def main():
             cache_dir=model_args.cache_dir,
             **dataset_args,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in dataset.keys():
@@ -465,6 +469,7 @@ def main():
                 cache_dir=model_args.cache_dir,
                 **dataset_args,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             dataset["train"] = load_dataset(
                 extension,
@@ -473,6 +478,7 @@ def main():
                 cache_dir=model_args.cache_dir,
                 **dataset_args,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
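For context on the repeated one-line change above: `datasets.load_dataset` has accepted a `num_proc` argument since `datasets` v2.5.0, which parallelizes download and preparation of the raw data across worker processes. Because `preprocessing_num_workers` defaults to `None`, and `num_proc=None` is also `load_dataset`'s default, the patch changes nothing unless the flag is set. A minimal sketch of the parameter in isolation; the dataset name and worker count are placeholders, not from this patch:

```python
from datasets import load_dataset

# Sketch only: num_proc parallelizes download and preparation.
# "username/my_large_dataset" and num_proc=8 are hypothetical values.
raw_datasets = load_dataset(
    "username/my_large_dataset",  # placeholder multi-shard dataset
    num_proc=8,  # up to 8 worker processes; num_proc=None (the default)
                 # keeps the previous single-process behavior
)
```

Workers beyond the number of shards or data files give no additional speedup, so the flag mainly helps on large, multi-file datasets.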
diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py
index 91d5c1c21b02..00c1bb32d099 100755
--- a/examples/flax/language-modeling/run_mlm_flax.py
+++ b/examples/flax/language-modeling/run_mlm_flax.py
@@ -458,6 +458,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in datasets.keys():
@@ -467,6 +468,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
@@ -474,6 +476,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -489,6 +492,7 @@ def main():
             data_files=data_files,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in datasets.keys():
@@ -498,6 +502,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 extension,
@@ -505,6 +510,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py
index 2d4f51474eb3..a4641dc21526 100755
--- a/examples/flax/language-modeling/run_t5_mlm_flax.py
+++ b/examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -572,6 +572,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in datasets.keys():
@@ -581,6 +582,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
@@ -588,6 +590,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     else:
         data_files = {}
@@ -603,6 +606,7 @@ def main():
             data_files=data_files,
             cache_dir=model_args.cache_dir,
             token=model_args.token,
+            num_proc=data_args.preprocessing_num_workers,
         )
 
         if "validation" not in datasets.keys():
@@ -612,6 +616,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
             datasets["train"] = load_dataset(
                 extension,
@@ -619,6 +624,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 token=model_args.token,
+                num_proc=data_args.preprocessing_num_workers,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
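The value being threaded through all four scripts comes from their existing `--preprocessing_num_workers` argument, which before this patch only reached the later `datasets.map` tokenization step. A condensed sketch of the flow under that assumption; the dataclass is heavily simplified from the real scripts, and only the field shown is intended to match them:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DataTrainingArguments:
    # Field as defined in the example scripts; all other fields omitted.
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

# Before this patch, the flag only parallelized tokenization:
#     tokenized = raw.map(tokenize_fn, num_proc=data_args.preprocessing_num_workers)
# With it, the same value also parallelizes the initial load_dataset call,
# as shown in the hunks above.
```

Reusing the one flag keeps the scripts' CLI surface unchanged while letting a single setting speed up both dataset loading and preprocessing.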