Skip to content

Commit

Permalink
add 'all' arg to download; add missing tasks to download
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Jan 17, 2025
1 parent 30820c8 commit b58d174
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
24 changes: 20 additions & 4 deletions totalsegmentator/bin/totalseg_download_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def main():
"head_glands_cavities", "head_muscles", "headneck_bones_vessels",
"headneck_muscles", "liver_vessels", "brain_structures",
"lung_nodules", "kidney_cysts", "breasts",
"thigh_shoulder_muscles", "thigh_shoulder_muscles_mr"],
"thigh_shoulder_muscles", "thigh_shoulder_muscles_mr",
"all"],
help="Task for which to download the weights", default="total")

args = parser.parse_args()
Expand Down Expand Up @@ -55,9 +56,11 @@ def main():
"lung_nodules": [913],
"kidney_cysts": [789],
"breasts": [527],
"oculomotor_muscles": [351],

"heartchambers_highres": [301],
"appendicular_bones": [304],
"appendicular_bones_mr": [855],
"tissue_types": [481],
"tissue_types_mr": [854],
"tissue_4_types": [485],
Expand All @@ -72,9 +75,22 @@ def main():
setup_totalseg()
set_config_key("statistics_disclaimer_shown", True)

for task_id in task_to_id[args.task]:
print(f"Processing {task_id}...")
download_pretrained_weights(task_id)
if args.task == "all":
# Get unique task IDs from all tasks
all_task_ids = set()
for task_ids in task_to_id.values():
if isinstance(task_ids, list):
all_task_ids.update(task_ids)
else:
all_task_ids.add(task_ids)

for task_id in sorted(all_task_ids):
print(f"Processing {task_id}...")
download_pretrained_weights(task_id)
else:
for task_id in task_to_id[args.task]:
print(f"Processing {task_id}...")
download_pretrained_weights(task_id)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions totalsegmentator/download_pretrained_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

if __name__ == "__main__":
"""
Download all pretrained weights
Download all pretrained weights (without commercial models)
"""
for task_id in [291, 292, 293, 294, 295, 297, 298, 258, 150, 260,
315, 299, 300, 850, 851, 852, 853, 775, 776, 777, 778,
779]:
779, 351, 913, 789, 527]:
download_pretrained_weights(task_id)
sleep(5)

0 comments on commit b58d174

Please sign in to comment.