
Commit
Merge pull request huggingface#22 from Superb-AI-Suite/process_evaluation

Modifications to eval, semantics extraction, label_maps, and collate_fn
SangbumChoi authored Sep 16, 2024
2 parents 49f5036 + 3af64be commit 698633f
Showing 3 changed files with 211 additions and 101 deletions.
@@ -15,6 +15,7 @@
"""Finetuning 🤗 Transformers model for object detection with Accelerate."""

import argparse
import json
import logging
import math
import os
@@ -33,6 +34,7 @@
from datasets import load_dataset
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm.auto import tqdm

import transformers
@@ -136,19 +138,36 @@ def convert_zero_shot_to_coco_format(predictions, label2id):
"""
# convert center to corners format
torch_label = []
for prediction in predictions:
scores = prediction["scores"]
device = scores.device
torch_label = []
labels = prediction["labels"]
for label in labels:
label = label.split(".")[0].split(" ")[0].replace(" ", "")
if label in label2id:
torch_label.append(label2id[label])
else:
# Give background class
torch_label.append(0)
prediction["labels"] = torch.Tensor(torch_label).to(dtype=torch.int32).to(device)
boxes = prediction["boxes"]

cnt_lst = []
for i, label in enumerate(labels):
label = label.split(" ")

cnt = 0
for lab in label:
if lab in label2id:
torch_label.append(label2id[lab])
cnt += 1
else:
for k, v in label2id.items():
if k.startswith(lab):
torch_label.append(v)
cnt += 1
break

cnt_lst.append(cnt)

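# repeat each box once per matched class so boxes stay aligned with the expanded labels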
if cnt_lst:
box_lst = []
for cnt, box in zip(cnt_lst, boxes):
box_lst.append(torch.stack([box] * cnt))
prediction["boxes"] = torch.cat(box_lst)

prediction["labels"] = torch.Tensor(torch_label).to(dtype=torch.int32).to(boxes.device)

return predictions
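
The rewritten matcher splits each predicted phrase on whitespace, tries an exact label2id lookup per word, falls back to a prefix match, and repeats the phrase's box once per matched class. A minimal sketch of the resulting behavior (the label2id mapping and tensors below are illustrative, not from this commit):

import torch

label2id = {"fish": 1, "shark": 2}
predictions = [{
    "scores": torch.tensor([0.9, 0.8]),
    "labels": ["fish", "shar"],  # the second phrase only prefix-matches "shark"
    "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]]),
}]
out = convert_zero_shot_to_coco_format(predictions, label2id)
# out[0]["labels"] -> tensor([1, 2], dtype=torch.int32)
# out[0]["boxes"] keeps one row per matched class (here, unchanged)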

@@ -285,62 +304,64 @@ def augment_and_transform_batch(
batch_index = []
for n, annotation in enumerate(annotations):
for anno in annotation["annotations"]:
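# skip degenerate boxes whose truncated width or height falls below one pixel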
x, y, w, h = anno["bbox"]
width = int(x + w) - int(x)
height = int(y + h) - int(y)
if width < 1 or height < 1:
continue

bbox.append(anno["bbox"])
category.append(anno["category_id"])
batch_index.append(n)

bool_category = get_random_unique_indices(category)
bool_category = group_by_index(bool_category, batch_index)
bbox = group_by_index(bbox, batch_index)

for image, bool, box in zip(images, bool_category, bbox):
bboxes = []
for _bool, _box in zip(bool, box):
if _bool:
bboxes.append(_box)
semantics.extend(crop_bboxes(image, bboxes))

try:
# Apply the image processor transformations: resizing, rescaling, normalization
result = processor(
images=images,
text=text,
semantics=semantics,
images_kwargs={"annotations": annotations},
return_tensors="pt",
)
except Exception:
print(Exception)
torch.save(annotations, "debug_annotations.pth")
torch.save(images, "debug_images.pth")
torch.save(semantics, "debug_semantics.pth")
for image in images:
print("image.shape", image.shape)
print(annotations)

result = processor(
images=images, text=text, semantics=None, images_kwargs={"annotations": annotations}, return_tensors="pt"
)
if bbox:
bool_category = get_random_unique_indices(category)
bool_category = group_by_index(bool_category, batch_index)
bbox = group_by_index(bbox, batch_index)

for image, bool, box in zip(images, bool_category, bbox):
bboxes = []
for _bool, _box in zip(bool, box):
if _bool:
bboxes.append(_box)
semantics.extend(crop_bboxes(image, bboxes))
else:
# semantics can be empty when random cropping removes every box
semantics = None

# Apply the image processor transformations: resizing, rescaling, normalization
result = processor(
images=images, text=text, semantics=semantics, images_kwargs={"annotations": annotations}, return_tensors="pt"
)

if not return_pixel_mask:
result.pop("pixel_mask", None)

return result

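crop_bboxes itself is not shown in this diff; purely as a hypothetical sketch of a helper consistent with its usage above (assuming HWC image arrays and COCO-style [x, y, w, h] boxes):

def crop_bboxes(image, bboxes):
    # hypothetical helper, not part of this commit: crop each
    # COCO-style [x, y, w, h] region out of an HWC image array
    crops = []
    for x, y, w, h in bboxes:
        x, y, w, h = int(x), int(y), int(w), int(h)
        crops.append(image[y : y + h, x : x + w])
    return crops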

def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
def collate_fn(batch: List[BatchFeature], use_semantics=False) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
data = {}
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
data["labels"] = [x["labels"] for x in batch]

if "input_ids" in batch[0]:
data["input_ids"] = torch.stack([x["input_ids"] for x in batch])
data["token_type_ids"] = torch.stack([x["token_type_ids"] for x in batch])
if "input_semantics" in batch[0]:
data["input_semantics"] = torch.stack([x["input_semantics"] for x in batch])
if use_semantics:
if "input_semantics" in batch[0]:
data["input_semantics"] = torch.stack([x["input_semantics"] for x in batch])
# Use either the text prompt or the visual prompt, never both.
if random.random() < 0.5:
data["input_ids"] = None
data["attention_mask"] = None
else:
data["input_semantics"] = None
if "pixel_mask" in batch[0]:
data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
if "attention_mask" in batch[0]:
data["attention_mask"] = torch.stack([x["attention_mask"] for x in batch])

return data
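
A quick sketch of what the new flag does to a collated batch, with stub features standing in for processor outputs (shapes are illustrative only):

import torch

def stub_feature():
    return {
        "pixel_values": torch.rand(3, 224, 224),
        "labels": {"class_labels": torch.zeros(0, dtype=torch.long)},
        "input_ids": torch.ones(8, dtype=torch.long),
        "token_type_ids": torch.zeros(8, dtype=torch.long),
        "input_semantics": torch.rand(4, 3, 128, 128),
    }

batch = collate_fn([stub_feature(), stub_feature()], use_semantics=True)
# the 50/50 flip keeps exactly one prompt modality per batch
print(batch["input_ids"] is None, batch["input_semantics"] is None)  # one True, one False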


@@ -366,7 +387,7 @@ def evaluation_loop(
label2id: Mapping[str, int],
) -> dict:
model.eval()
metrics = {}
metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)

for step, batch in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)):
with torch.no_grad():
@@ -378,24 +399,42 @@
# the processor converts boxes from YOLO format to Pascal VOC format
# ([x_min, y_min, x_max, y_max] in absolute coordinates)
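# thresholds of 0.0 pass every candidate box to the metric; mAP ranks detections by score internally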
image_size = torch.stack([example["orig_size"] for example in batch["labels"]], dim=0)
predictions = processor.post_process_grounded_object_detection_v2(
outputs, box_threshold=0.15, target_sizes=image_size
predictions = processor.post_process_grounded_object_detection(
outputs,
batch["input_ids"] if "input_ids" in batch else None,
batch["input_semantics"] if "input_semantics" in batch else None,
box_threshold=0.0,
multimodal_threshold=0.0,
target_sizes=image_size,
)
predictions = nested_to_cpu(predictions)
metrics["val_loss"] = (
(metrics["val_loss"] + outputs.loss.detach().float())
if "val_loss" in metrics
else outputs.loss.detach().float()
)
for name, value in outputs.loss_dict.items():
if name not in metrics:
metrics[f"val_{name}"] = value.detach().float()
else:
metrics[f"val_{name}"] += value.detach().float()

# normalize
for name in metrics:
metrics[name] /= step
predictions = convert_zero_shot_to_coco_format(predictions, {k.lower(): v for k, v in label2id.items()})

# 2. Collect ground truth boxes in the same format for metric computation
# Do the same, convert YOLO boxes to Pascal VOC format
target = []
for label in batch["labels"]:
label = nested_to_cpu(label)
boxes = convert_bbox_yolo_to_pascal(label["boxes"], label["orig_size"])
labels = label["class_labels"]
target.append({"boxes": boxes, "labels": labels})
# print(target[0]["labels"])
metric.update(predictions, target)

metric.to(accelerator.device)
metrics = metric.compute()

# Replace the list of per-class metrics with a separate metric for each class
classes = metrics.pop("classes")
map_per_class = metrics.pop("map_per_class")
mar_100_per_class = metrics.pop("mar_100_per_class")
for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
class_name = id2label[class_id.item()]
metrics[f"map_{class_name}"] = class_map
metrics[f"mar_100_{class_name}"] = class_mar

# Convert metrics to float
metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

return metrics
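
For reference, a self-contained sketch of the per-class flattening that the loop above performs (toy boxes, single class id 1):

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
preds = [{"boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
          "scores": torch.tensor([0.9]),
          "labels": torch.tensor([1])}]
target = [{"boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
           "labels": torch.tensor([1])}]
metric.update(preds, target)
out = metric.compute()
# with class_metrics=True, compute() returns "classes", "map_per_class",
# and "mar_100_per_class", which evaluation_loop expands into
# map_<class_name> / mar_100_<class_name> entries
print(out["classes"], out["map_per_class"], out["mar_100_per_class"])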

@@ -738,30 +777,37 @@ def main():
processor=processor,
id2label=id2label,
label2id=label2id,
random_text_prompt=True,
random_text_prompt=False,  # fixed prompts allow class-wise AP computation during evaluation
)

with accelerator.main_process_first():
train_dataset = dataset["train"].with_transform(train_transform_batch)
valid_dataset = dataset["validation"].with_transform(validation_transform_batch)
test_dataset = dataset["test"].with_transform(validation_transform_batch)

dataloader_common_args = {
"num_workers": args.dataloader_num_workers,
"collate_fn": collate_fn,
}
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
drop_last=True,
batch_size=args.per_device_train_batch_size,
**dataloader_common_args,
num_workers=args.dataloader_num_workers,
collate_fn=partial(collate_fn, use_semantics=True),
)
valid_dataloader = DataLoader(
valid_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
valid_dataset,
shuffle=False,
batch_size=args.per_device_eval_batch_size,
num_workers=args.dataloader_num_workers,
# use_semantics can be set to True once evaluation_loop can incorporate visual prompts in the near future
collate_fn=partial(collate_fn, use_semantics=False),
)
test_dataloader = DataLoader(
test_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
test_dataset,
shuffle=False,
batch_size=args.per_device_eval_batch_size,
num_workers=args.dataloader_num_workers,
# use_semantics can be set to True once evaluation_loop can incorporate visual prompts in the near future
collate_fn=partial(collate_fn, use_semantics=False),
)

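Since collate_fn now takes a second argument, functools.partial pre-binds use_semantics so each DataLoader still receives a single-argument callable (partial is assumed to be imported at the top of the script):

from functools import partial

collate = partial(collate_fn, use_semantics=True)
# collate(features) is equivalent to collate_fn(features, use_semantics=True)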
# ------------------------------------------------------------------------------------------------
@@ -947,7 +993,7 @@ def main():
if args.with_tracking:
accelerator.log(
{
"total_train_loss": total_loss / len(train_dataloader),
"total_train_loss": total_loss.item() / len(train_dataloader),
**metrics,
"epoch": epoch,
"step": completed_steps,
62 changes: 42 additions & 20 deletions src/transformers/models/grounding_dino2/modeling_grounding_dino2.py
@@ -2150,14 +2150,14 @@ def generate_masks_with_special_tokens_and_transfer_map(input_ids: torch.LongTen
def generate_masks_with_input_semantics(batch_size: int, input_semantics: torch.FloatTensor) -> Tuple[Tensor, Tensor]:
"""Generate attention mask between each pair of special tokens and positional ids.
Args:
input_semantics (`torch.FloatTensor` of shape `(semantic_length, 3, width, height)`):
input_semantics (`torch.FloatTensor` of shape `(batch_size, semantic_length, 3, height, width)`):
Indices of input semantics after image processor.
Returns:
`tuple(torch.Tensor)` comprising attention mask between each special tokens and position_ids:
- **attention_mask** (`torch.BoolTensor` of shape `(batch_size, sequence_length)`)
- **position_ids** (`torch.LongTensor` of shape `(batch_size, sequence_length)`)
"""
_, num_semantic, _, width, height = input_semantics.shape
_, num_semantic, _, height, width = input_semantics.shape

# generate attention mask and positional ids
semantic_self_attention_masks = (
@@ -2503,7 +2503,7 @@ def forward(
semantic_features = semantic_features.transpose(0, 1).expand(batch_size, -1, -1)

if input_ids is not None and input_semantics is not None:
# Multimodal_features are shape of (batch_size, text_length + semantic_lenght, hidden_dim)
# Multimodal_features are shape of (batch_size, text_length + semantic_length, hidden_dim)
multimodal_features = torch.cat([text_features, semantic_features], dim=1)

text_features_dim, semantic_features_dim = text_features.shape[1], semantic_features.shape[1]
@@ -3142,34 +3142,56 @@ def forward(self, outputs, targets):
return losses


def build_label_maps(logits, input_ids):
def build_label_maps(logits, input_ids, input_semantics):
"""
Computes a mapping between the tokens associated with the prompt labels in the logit space with shape `(batch_size, num_labels, hidden_size)`
where `num_labels` is defined by the number of classes in the input prompt.
For instance, given the prompt "fish. shark." we get input_ids = [ 101, 3869, 1012, 11420, 1012, 102].
This function will return a mapping for each of the prompt tokens (i.e. tokens associated with "fish" and "shark")
indicating their position in the logit space.
"""
hidden_size = logits.shape[-1]
# Add [PAD] token to the list of special tokens
delimiter_tokens = torch.tensor(SPECIAL_TOKENS + [0], device=input_ids.device)

delimiter_token_masks = torch.isin(input_ids, delimiter_tokens)
label_maps = ()
# Find where the delimiter tokens are
for delimiter_token_mask in delimiter_token_masks:
if input_ids is not None:
# Add [PAD] token to the list of special tokens
delimiter_tokens = torch.tensor(SPECIAL_TOKENS + [0], device=input_ids.device)

delimiter_token_masks = torch.isin(input_ids, delimiter_tokens)
# Find where the delimiter tokens are
for delimiter_token_mask in delimiter_token_masks:
label_map_within_batch = []
delimiter_indices = torch.where(delimiter_token_mask)[0]
for i in range(len(delimiter_indices) - 1):
start = delimiter_indices[i]
end = delimiter_indices[i + 1]
if end - start > 1:
label_map = torch.zeros(hidden_size, device=input_ids.device)
label_map[start + 1 : end] = 1
label_map_within_batch.append(label_map)

## for the case when the text prompt and visual prompt are given together
# if input_semantics is not None:
# num_ids = input_ids.shape[-1]
# num_semantics = input_semantics.shape[1]
# for i in range(num_ids, num_ids + num_semantics):
# label_map = torch.zeros(hidden_size, device=input_semantics.device)
# label_map[i] = 1
# label_map_within_batch.append(label_map)

label_maps += (torch.stack(label_map_within_batch),)
elif input_semantics is not None:
label_map_within_batch = []
delimiter_indices = torch.where(delimiter_token_mask)[0]
for i in range(len(delimiter_indices) - 1):
start = delimiter_indices[i]
end = delimiter_indices[i + 1]
if end - start > 1:
label_map = torch.zeros(hidden_size, device=input_ids.device)
label_map[start + 1 : end] = 1
label_map_within_batch.append(label_map)
num_semantics = input_semantics.shape[1]
for i in range(num_semantics):
label_map = torch.zeros(hidden_size, device=input_semantics.device)
label_map[i] = 1
label_map_within_batch.append(label_map)

label_maps += (torch.stack(label_map_within_batch),)
label_maps += (torch.stack(label_map_within_batch),) * len(input_semantics)
else:
raise ValueError(f"input_ids : {input_ids}, input_semantics : {input_semantics} can't be both None")

return label_maps
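
A worked sketch of the text branch, following the docstring's "fish. shark." example (the tiny hidden size is illustrative, and SPECIAL_TOKENS is assumed to contain the BERT ids 101, 102, and 1012, as in the upstream Grounding DINO code):

import torch

input_ids = torch.tensor([[101, 3869, 1012, 11420, 1012, 102]])  # "fish. shark."
logits = torch.zeros(1, 900, 6)  # (batch_size, num_queries, hidden_size)
label_maps = build_label_maps(logits, input_ids, input_semantics=None)
# one binary map per class, marking the token positions of "fish" and "shark":
# label_maps[0] -> tensor([[0., 1., 0., 0., 0., 0.],
#                          [0., 0., 0., 1., 0., 0.]])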

@@ -3381,7 +3403,7 @@ def forward(
losses=losses,
)
criterion.to(self.device)
label_maps = build_label_maps(logits, input_ids)
label_maps = build_label_maps(logits, input_ids, input_semantics)
multimodal_mask = build_multimodal_mask(logits, attention_mask)
# Third: compute the losses, based on outputs and labels
outputs_loss = {}