Skip to content

Commit

Permalink
add ABCNet
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-haochen committed May 11, 2020
1 parent 8029f20 commit 30fdbaf
Show file tree
Hide file tree
Showing 40 changed files with 5,406 additions and 34 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ To date, AdelaiDet implements the following algorithms:
* [FCOS](configs/FCOS-Detection/README.md)
* [BlendMask](configs/BlendMask/README.md)
* [MEInst](configs/MEInst-InstanceSegmentation/README.md)
* [ABCNet](https://arxiv.org/abs/2002.10200) _to be released_ ([demo](https://github.com/Yuliang-Liu/bezier_curve_text_spotting))
* [ABCNet](configs/BAText/README.md)
* [SOLO](https://arxiv.org/abs/1912.04488) _to be released_ ([mmdet version](https://github.com/WXinlong/SOLO))
* [SOLOv2](https://arxiv.org/abs/2003.10152) _to be released_ ([mmdet version](https://github.com/WXinlong/SOLO))
* [DirectPose](https://arxiv.org/abs/1911.07451) _to be released_
Expand Down Expand Up @@ -45,6 +45,14 @@ Name | inf. time | box AP | mask AP | download

For more models and information, please refer to MEInst [README.md](configs/MEInst-InstanceSegmentation/README.md).

### Total_Text results with [ABCNet](https://arxiv.org/abs/2002.10200)

Name | inf. time | e2e-hmean | det-hmean | download
--- |:---------:|:---------:|:---------:|:---:
[attn_R_50](configs/BAText/TotalText/attn_R_50.yaml) | 11 FPS | 62.7 | 82.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/nyyNRdP7VBYqfgl/download)

For more models and information, please refer to ABCNet [README.md](configs/BAText/README.md).

Note that:
- Inference time for all projects is measured on a NVIDIA 1080Ti with batch size 1.
- APs are evaluated on COCO2017 val split unless specified.
Expand Down
26 changes: 26 additions & 0 deletions adet/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
_C.MODEL.MOBILENET = False
_C.MODEL.BACKBONE.ANTI_ALIAS = False
_C.MODEL.RESNETS.DEFORM_INTERVAL = 1
_C.INPUT.HFLIP = True
_C.INPUT.CROP.CROP_INSTANCE = True

# ---------------------------------------------------------------------------- #
# FCOS Head
Expand Down Expand Up @@ -73,6 +75,23 @@
# Options: FrozenBN, GN, "SyncBN", "BN"
_C.MODEL.DLA.NORM = "FrozenBN"

# ---------------------------------------------------------------------------- #
# BAText Options
# ---------------------------------------------------------------------------- #
_C.MODEL.BATEXT = CN()
_C.MODEL.BATEXT.VOC_SIZE = 96
_C.MODEL.BATEXT.NUM_CHARS = 25
_C.MODEL.BATEXT.POOLER_RESOLUTION = (8, 32)
_C.MODEL.BATEXT.IN_FEATURES = ["p2", "p3", "p4"]
_C.MODEL.BATEXT.POOLER_SCALES = (0.25, 0.125, 0.0625)
_C.MODEL.BATEXT.SAMPLING_RATIO = 1
_C.MODEL.BATEXT.CONV_DIM = 256
_C.MODEL.BATEXT.NUM_CONV = 2
_C.MODEL.BATEXT.RECOGNITION_LOSS = "ctc"
_C.MODEL.BATEXT.RECOGNIZER = "attn"
_C.MODEL.BATEXT.CANONICAL_SIZE = 96 # largest min_size for level 3 (stride=8)
_C.MODEL.BATEXT.TEST_CONFIDENCE_THRESHOLD = 0.7 # [0.0 - 1.0]

# ---------------------------------------------------------------------------- #
# BlendMask Options
# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -180,3 +199,10 @@
_C.MODEL.MEInst.GCN_KERNEL_SIZE = 9
# Whether to compute loss on original mask (binary mask).
_C.MODEL.MEInst.LOSS_ON_MASK = False

# ---------------------------------------------------------------------------- #
# TOP Module Options
# ---------------------------------------------------------------------------- #
_C.MODEL.TOP_MODULE = CN()
_C.MODEL.TOP_MODULE.NAME = "conv"
_C.MODEL.TOP_MODULE.DIM = 16
28 changes: 26 additions & 2 deletions adet/data/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,48 @@
from detectron2.data.datasets.register_coco import register_coco_instances
from detectron2.data.datasets.builtin_meta import _get_builtin_metadata

from .datasets.text import register_text_instances

# register plane reconstruction

_PREDEFINED_SPLITS_PIC = {
"pic_person_train": ("pic/image/train", "pic/annotations/train_person.json"),
"pic_person_val": ("pic/image/val", "pic/annotations/val_person.json"),
}

metadata = {
metadata_pic = {
"thing_classes": ["person"]
}

_PREDEFINED_SPLITS_TEXT = {
"totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
"totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
"ctw1500_word_train": ("CTW1500/ctwtrain_text_image", "CTW1500/annotations/train_ctw1500_maxlen100_v2.json"),
"ctw1500_word_test": ("CTW1500/ctwtest_text_image","CTW1500/annotations/test_ctw1500_maxlen100.json"),
"syntext1_train": ("syntext1/syntext_word_eng", "syntext1/annotations/becan_syn_word_maxlen25.json"),
"syntext2_train": ("syntext2/emcs_imgs", "syntext2/annotations/ecms_v1_maxlen25.json"),
"mltbezier_word_train": ("extract/MLT_train_images","extract/annotations/train.json"),
}

metadata_text = {
"thing_classes": ["text"]
}


def register_all_coco(root="datasets"):
for key, (image_root, json_file) in _PREDEFINED_SPLITS_PIC.items():
# Assume pre-defined datasets live in `./datasets`.
register_coco_instances(
key,
metadata,
metadata_pic,
os.path.join(root, json_file) if "://" not in json_file else json_file,
os.path.join(root, image_root),
)
for key, (image_root, json_file) in _PREDEFINED_SPLITS_TEXT.items():
# Assume pre-defined datasets live in `./datasets`.
register_text_instances(
key,
metadata_text,
os.path.join(root, json_file) if "://" not in json_file else json_file,
os.path.join(root, image_root),
)
Expand Down
20 changes: 16 additions & 4 deletions adet/data/dataset_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

from .detection_utils import (
build_transform_gen,
transform_instance_annotations,
annotations_to_instances,
gen_crop_transform_with_instance,
)

"""
This file contains the default mapping that's applied to "dataset dicts".
"""
Expand All @@ -24,9 +31,13 @@ class DatasetMapperWithBasis(DatasetMapper):
def __init__(self, cfg, is_train=True):
super().__init__(cfg, is_train)

# rebuild transform gen
self.tfm_gens = build_transform_gen(cfg, is_train)

# fmt: off
self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON
self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET
self.crop_box = cfg.INPUT.CROP.CROP_INSTANCE
# fmt: on

def __call__(self, dataset_dict):
Expand Down Expand Up @@ -64,10 +75,11 @@ def __call__(self, dataset_dict):
# Crop around an instance if there are instances in the image.
# USER: Remove if you don't use cropping
if self.crop_gen:
crop_tfm = utils.gen_crop_transform_with_instance(
crop_tfm = gen_crop_transform_with_instance(
self.crop_gen.get_crop_size(image.shape[:2]),
image.shape[:2],
np.random.choice(dataset_dict["annotations"]),
dataset_dict["annotations"],
crop_box=self.crop_box,
)
image = crop_tfm.apply_image(image)
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
Expand Down Expand Up @@ -104,13 +116,13 @@ def __call__(self, dataset_dict):

# USER: Implement additional transformations if you have other types of data
annos = [
utils.transform_instance_annotations(
transform_instance_annotations(
obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
)
for obj in dataset_dict.pop("annotations")
if obj.get("iscrowd", 0) == 0
]
instances = utils.annotations_to_instances(
instances = annotations_to_instances(
annos, image_shape, mask_format=self.mask_format
)
# Create a tight bounding box from masks, useful when image is cropped
Expand Down
203 changes: 203 additions & 0 deletions adet/data/datasets/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import io
import logging
import os
from fvcore.common.timer import Timer
from fvcore.common.file_io import PathManager

from detectron2.structures import BoxMode

from detectron2.data import DatasetCatalog, MetadataCatalog

"""
This file contains functions to parse COCO-format text annotations into dicts in "Detectron2 format".
"""


logger = logging.getLogger(__name__)

__all__ = ["load_text_json", "register_text_instances"]


def register_text_instances(name, metadata, json_file, image_root):
"""
Register a dataset in json annotation format for text detection and recognition.
Args:
name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
json_file (str): path to the json instance annotation file.
image_root (str or path-like): directory which contains all the images.
"""
DatasetCatalog.register(name, lambda: load_text_json(json_file, image_root, name))
MetadataCatalog.get(name).set(
json_file=json_file, image_root=image_root, evaluator_type="text", **metadata
)


def load_text_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
"""
Load a json file with totaltext annotation format.
Currently supports text detection and recognition.
Args:
json_file (str): full path to the json file in totaltext annotation format.
image_root (str or path-like): the directory where the images in this json file exists.
dataset_name (str): the name of the dataset (e.g., coco_2017_train).
If provided, this function will also put "thing_classes" into
the metadata associated with this dataset.
extra_annotation_keys (list[str]): list of per-annotation keys that should also be
loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
"category_id", "segmentation"). The values for these keys will be returned as-is.
For example, the densepose annotations are loaded in this way.
Returns:
list[dict]: a list of dicts in Detectron2 standard dataset dicts format. (See
`Using Custom Datasets </tutorials/datasets.html>`_ )
Notes:
1. This function does not read the image files.
The results do not have the "image" field.
"""
from pycocotools.coco import COCO

timer = Timer()
json_file = PathManager.get_local_path(json_file)
with contextlib.redirect_stdout(io.StringIO()):
coco_api = COCO(json_file)
if timer.seconds() > 1:
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

id_map = None
if dataset_name is not None:
meta = MetadataCatalog.get(dataset_name)
cat_ids = sorted(coco_api.getCatIds())
cats = coco_api.loadCats(cat_ids)
# The categories in a custom json file may not be sorted.
thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
meta.thing_classes = thing_classes

# In COCO, certain category ids are artificially removed,
# and by convention they are always ignored.
# We deal with COCO's id issue and translate
# the category ids to contiguous ids in [0, 80).

# It works by looking at the "categories" field in the json, therefore
# if users' own json also have incontiguous ids, we'll
# apply this mapping as well but print a warning.
if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
if "coco" not in dataset_name:
logger.warning(
"""
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
"""
)
id_map = {v: i for i, v in enumerate(cat_ids)}
meta.thing_dataset_id_to_contiguous_id = id_map

# sort indices for reproducible results
img_ids = sorted(coco_api.imgs.keys())
# imgs is a list of dicts, each looks something like:
# {'license': 4,
# 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
# 'file_name': 'COCO_val2014_000000001268.jpg',
# 'height': 427,
# 'width': 640,
# 'date_captured': '2013-11-17 05:57:24',
# 'id': 1268}
imgs = coco_api.loadImgs(img_ids)
# anns is a list[list[dict]], where each dict is an annotation
# record for an object. The inner list enumerates the objects in an image
# and the outer list enumerates over images. Example of anns[0]:
# [{'segmentation': [[192.81,
# 247.09,
# ...
# 219.03,
# 249.06]],
# 'area': 1035.749,
# 'rec': [84, 72, ... 96],
# 'bezier_pts': [169.0, 425.0, ..., ]
# 'iscrowd': 0,
# 'image_id': 1268,
# 'bbox': [192.81, 224.8, 74.73, 33.43],
# 'category_id': 16,
# 'id': 42986},
# ...]
anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]

if "minival" not in json_file:
# The popular valminusminival & minival annotations for COCO2014 contain this bug.
# However the ratio of buggy annotations there is tiny and does not affect accuracy.
# Therefore we explicitly white-list them.
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
json_file
)

imgs_anns = list(zip(imgs, anns))

logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))

dataset_dicts = []

ann_keys = ["iscrowd", "bbox", "rec", "category_id"] + (extra_annotation_keys or [])

num_instances_without_valid_segmentation = 0

for (img_dict, anno_dict_list) in imgs_anns:
record = {}
record["file_name"] = os.path.join(image_root, img_dict["file_name"])
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
image_id = record["image_id"] = img_dict["id"]

objs = []
for anno in anno_dict_list:
# Check that the image_id in this annotation is the same as
# the image_id we're looking at.
# This fails only when the data parsing logic or the annotation file is buggy.

# The original COCO valminusminival2014 & minival2014 annotation files
# actually contains bugs that, together with certain ways of using COCO API,
# can trigger this assertion.
assert anno["image_id"] == image_id

assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'

obj = {key: anno[key] for key in ann_keys if key in anno}

segm = anno.get("segmentation", None)
if segm: # either list[list[float]] or dict(RLE)
if not isinstance(segm, dict):
# filter out invalid polygons (< 3 points)
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
if len(segm) == 0:
num_instances_without_valid_segmentation += 1
continue # ignore this instance
obj["segmentation"] = segm

bezierpts = anno.get("bezier_pts", None)
# Bezier Points are the control points for BezierAlign Text recognition (BAText)
if bezierpts: # list[float]
obj["beziers"] = bezierpts

text = anno.get("rec", None)
if text:
obj["text"] = text

obj["bbox_mode"] = BoxMode.XYWH_ABS
if id_map:
obj["category_id"] = id_map[obj["category_id"]]
objs.append(obj)
record["annotations"] = objs
dataset_dicts.append(record)

if num_instances_without_valid_segmentation > 0:
logger.warning(
"Filtered out {} instances without valid segmentation. "
"There might be issues in your dataset generation process.".format(
num_instances_without_valid_segmentation
)
)
return dataset_dicts
Loading

0 comments on commit 30fdbaf

Please sign in to comment.