diff --git a/extract_features.py b/extract_features.py index 6ad3a90e009..3c3659c89f2 100644 --- a/extract_features.py +++ b/extract_features.py @@ -211,6 +211,7 @@ def main(): parser.add_argument("--do_lower_case", default=True, action='store_true', help="Whether to lower case the input text. Should be True for uncased " "models and False for cased models.") + parser.add_argument("--no_cuda", default=False, type=bool, help="Is cuda available?") parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") parser.add_argument("--local_rank", type=int,