Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert SlimSAM checkpoints #28379

Merged
merged 8 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
# limitations under the License.
"""
Convert SAM checkpoints from the original repository.

URL: https://github.com/facebookresearch/segment-anything.

Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
"""
import argparse
import re
Expand All @@ -33,6 +37,47 @@
)


def get_config(model_name):
if "slimsam-50" in model_name:
vision_config = SamVisionConfig(
hidden_size=384,
mlp_dim=1536,
num_hidden_layers=12,
num_attention_heads=12,
global_attn_indexes=[2, 5, 8, 11],
)
elif "slimsam-77" in model_name:
vision_config = SamVisionConfig(
hidden_size=168,
mlp_dim=696,
num_hidden_layers=12,
num_attention_heads=12,
global_attn_indexes=[2, 5, 8, 11],
)
elif "sam_vit_b" in model_name:
vision_config = SamVisionConfig()
elif "sam_vit_l" in model_name:
vision_config = SamVisionConfig(
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
global_attn_indexes=[5, 11, 17, 23],
)
elif "sam_vit_h" in model_name:
vision_config = SamVisionConfig(
hidden_size=1280,
num_hidden_layers=32,
num_attention_heads=16,
global_attn_indexes=[7, 15, 23, 31],
)

config = SamConfig(
vision_config=vision_config,
)

return config


KEYS_TO_MODIFY_MAPPING = {
"iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
"iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
Expand Down Expand Up @@ -88,63 +133,47 @@ def replace_keys(state_dict):
return model_state_dict


def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_hub_id="ybelkada/segment-anything"):
checkpoint_path = hf_hub_download(model_hub_id, f"checkpoints/{model_name}.pth")

if "sam_vit_b" in model_name:
config = SamConfig()
elif "sam_vit_l" in model_name:
vision_config = SamVisionConfig(
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
global_attn_indexes=[5, 11, 17, 23],
)

config = SamConfig(
vision_config=vision_config,
)
elif "sam_vit_h" in model_name:
vision_config = SamVisionConfig(
hidden_size=1280,
num_hidden_layers=32,
num_attention_heads=16,
global_attn_indexes=[7, 15, 23, 31],
)

config = SamConfig(
vision_config=vision_config,
)
def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub):
config = get_config(model_name)

state_dict = torch.load(checkpoint_path, map_location="cpu")
state_dict = replace_keys(state_dict)

image_processor = SamImageProcessor()

processor = SamProcessor(image_processor=image_processor)
hf_model = SamModel(config)
hf_model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"

hf_model.load_state_dict(state_dict)
hf_model = hf_model.to("cuda")
hf_model = hf_model.to(device)

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

input_points = [[[400, 650]]]
input_points = [[[500, 375]]]
input_labels = [[1]]

inputs = processor(images=np.array(raw_image), return_tensors="pt").to("cuda")
inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device)

with torch.no_grad():
output = hf_model(**inputs)
scores = output.iou_scores.squeeze()

if model_name == "sam_vit_h_4b8939":
assert scores[-1].item() == 0.579890251159668
if model_name == "sam_vit_b_01ec64":
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(device)

with torch.no_grad():
output = hf_model(**inputs)
scores = output.iou_scores.squeeze()

elif model_name == "sam_vit_h_4b8939":
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to("cuda")
).to(device)

with torch.no_grad():
output = hf_model(**inputs)
Expand All @@ -154,7 +183,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h

input_boxes = ((75, 275, 1725, 850),)

inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to("cuda")
inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device)

with torch.no_grad():
output = hf_model(**inputs)
Expand All @@ -168,39 +197,54 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h

inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to("cuda")
).to(device)

with torch.no_grad():
output = hf_model(**inputs)
scores = output.iou_scores.squeeze()

assert scores[-1].item() == 0.9936047792434692

if pytorch_dump_folder is not None:
processor.save_pretrained(pytorch_dump_folder)
hf_model.save_pretrained(pytorch_dump_folder)

if push_to_hub:
repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}"
processor.push_to_hub(repo_id)
hf_model.push_to_hub(repo_id)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195"]
choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"]
parser.add_argument(
"--model_name",
default="sam_vit_h_4b8939",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
help="Name of the original model to convert",
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=False,
help="Path to the original checkpoint",
)
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model and processor to the hub after converting",
)
parser.add_argument(
"--model_hub_id",
default="ybelkada/segment-anything",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
)

args = parser.parse_args()

convert_sam_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.model_hub_id)
if "slimsam" in args.model_name:
checkpoint_path = args.checkpoint_path
if checkpoint_path is None:
raise ValueError("You need to provide a checkpoint path for SlimSAM models.")
else:
checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth")

convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
2 changes: 1 addition & 1 deletion utils/not_doctested.txt
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ src/transformers/models/rwkv/configuration_rwkv.py
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/sam/configuration_sam.py
src/transformers/models/sam/convert_sam_original_to_hf_format.py
src/transformers/models/sam/convert_sam_to_hf.py
src/transformers/models/sam/image_processing_sam.py
src/transformers/models/sam/modeling_sam.py
src/transformers/models/sam/modeling_tf_sam.py
Expand Down
Loading