diff --git a/totalsegmentator/bin/totalseg_download_weights.py b/totalsegmentator/bin/totalseg_download_weights.py index 18a191ae2..430c0f52d 100644 --- a/totalsegmentator/bin/totalseg_download_weights.py +++ b/totalsegmentator/bin/totalseg_download_weights.py @@ -27,7 +27,8 @@ def main(): "head_glands_cavities", "head_muscles", "headneck_bones_vessels", "headneck_muscles", "liver_vessels", "brain_structures", "lung_nodules", "kidney_cysts", "breasts", - "thigh_shoulder_muscles", "thigh_shoulder_muscles_mr"], + "thigh_shoulder_muscles", "thigh_shoulder_muscles_mr", + "all"], help="Task for which to download the weights", default="total") args = parser.parse_args() @@ -55,9 +56,11 @@ def main(): "lung_nodules": [913], "kidney_cysts": [789], "breasts": [527], + "oculomotor_muscles": [351], "heartchambers_highres": [301], "appendicular_bones": [304], + "appendicular_bones_mr": [855], "tissue_types": [481], "tissue_types_mr": [854], "tissue_4_types": [485], @@ -72,9 +75,22 @@ def main(): setup_totalseg() set_config_key("statistics_disclaimer_shown", True) - for task_id in task_to_id[args.task]: - print(f"Processing {task_id}...") - download_pretrained_weights(task_id) + if args.task == "all": + # Get unique task IDs from all tasks + all_task_ids = set() + for task_ids in task_to_id.values(): + if isinstance(task_ids, list): + all_task_ids.update(task_ids) + else: + all_task_ids.add(task_ids) + + for task_id in sorted(all_task_ids): + print(f"Processing {task_id}...") + download_pretrained_weights(task_id) + else: + for task_id in task_to_id[args.task]: + print(f"Processing {task_id}...") + download_pretrained_weights(task_id) if __name__ == "__main__": diff --git a/totalsegmentator/download_pretrained_weights.py b/totalsegmentator/download_pretrained_weights.py index 885ae3b52..7eb4475a4 100644 --- a/totalsegmentator/download_pretrained_weights.py +++ b/totalsegmentator/download_pretrained_weights.py @@ -4,10 +4,10 @@ if __name__ == "__main__": """ - Download all pretrained weights + Download all pretrained weights (without commercial models) """ for task_id in [291, 292, 293, 294, 295, 297, 298, 258, 150, 260, 315, 299, 300, 850, 851, 852, 853, 775, 776, 777, 778, - 779]: + 779, 351, 913, 789, 527]: download_pretrained_weights(task_id) sleep(5)