# Source code for datumaro.plugins.tf_detection_api_format.extractor
# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os.path as osp
import re
from collections import OrderedDict
import numpy as np
from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories, Mask
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.media import ByteImage
from datumaro.util.image import decode_image, lazy_image
from datumaro.util.tf_util import import_tf as _import_tf
from .format import DetectionApiPath
tf = _import_tf()
[docs]class TfDetectionApiExtractor(SourceExtractor):
[docs] def __init__(self, path, subset=None):
assert osp.isfile(path), path
images_dir = ""
root_dir = osp.dirname(osp.abspath(path))
if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
root_dir = osp.dirname(root_dir)
images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR)
if not osp.isdir(images_dir):
images_dir = ""
if not subset:
subset = osp.splitext(osp.basename(path))[0]
super().__init__(subset=subset)
items, labels = self._parse_tfrecord_file(path, self._subset, images_dir)
self._items = items
self._categories = self._load_categories(labels)
@staticmethod
def _load_categories(labels):
label_categories = LabelCategories().from_iterable(
e[0] for e in sorted(labels.items(), key=lambda item: item[1])
)
return {AnnotationType.label: label_categories}
@classmethod
def _parse_labelmap(cls, text):
id_pattern = r"(?:id\s*:\s*(?P<id>\d+))"
name_pattern = r"(?:name\s*:\s*[\'\"](?P<name>.*?)[\'\"])"
entry_pattern = r"(\{(?:[\s\n]*(?:%(id)s|%(name)s)[\s\n]*){2}\})+" % {
"id": id_pattern,
"name": name_pattern,
}
matches = re.finditer(entry_pattern, text)
labelmap = {}
for match in matches:
label_id = match.group("id")
label_name = match.group("name")
if label_id is not None and label_name is not None:
labelmap[label_name] = int(label_id)
return labelmap
@classmethod
def _parse_tfrecord_file(cls, filepath, subset, images_dir):
dataset = tf.data.TFRecordDataset(filepath)
features = {
"image/filename": tf.io.FixedLenFeature([], tf.string),
"image/source_id": tf.io.FixedLenFeature([], tf.string),
"image/height": tf.io.FixedLenFeature([], tf.int64),
"image/width": tf.io.FixedLenFeature([], tf.int64),
"image/encoded": tf.io.FixedLenFeature([], tf.string),
"image/format": tf.io.FixedLenFeature([], tf.string),
# use varlen to avoid errors when this field is missing
"image/key/sha256": tf.io.VarLenFeature(tf.string),
# Object boxes and classes.
"image/object/bbox/xmin": tf.io.VarLenFeature(tf.float32),
"image/object/bbox/xmax": tf.io.VarLenFeature(tf.float32),
"image/object/bbox/ymin": tf.io.VarLenFeature(tf.float32),
"image/object/bbox/ymax": tf.io.VarLenFeature(tf.float32),
"image/object/class/label": tf.io.VarLenFeature(tf.int64),
"image/object/class/text": tf.io.VarLenFeature(tf.string),
"image/object/mask": tf.io.VarLenFeature(tf.string),
}
dataset_labels = OrderedDict()
labelmap_path = osp.join(osp.dirname(filepath), DetectionApiPath.LABELMAP_FILE)
if osp.exists(labelmap_path):
with open(labelmap_path, "r", encoding="utf-8") as f:
labelmap_text = f.read()
dataset_labels.update(
{label: id - 1 for label, id in cls._parse_labelmap(labelmap_text).items()}
)
dataset_items = []
for record in dataset:
parsed_record = tf.io.parse_single_example(record, features)
frame_id = parsed_record["image/source_id"].numpy().decode("utf-8")
frame_filename = parsed_record["image/filename"].numpy().decode("utf-8")
frame_height = tf.cast(parsed_record["image/height"], tf.int64).numpy().item()
frame_width = tf.cast(parsed_record["image/width"], tf.int64).numpy().item()
frame_image = parsed_record["image/encoded"].numpy()
xmins = tf.sparse.to_dense(parsed_record["image/object/bbox/xmin"]).numpy()
ymins = tf.sparse.to_dense(parsed_record["image/object/bbox/ymin"]).numpy()
xmaxs = tf.sparse.to_dense(parsed_record["image/object/bbox/xmax"]).numpy()
ymaxs = tf.sparse.to_dense(parsed_record["image/object/bbox/ymax"]).numpy()
label_ids = tf.sparse.to_dense(parsed_record["image/object/class/label"]).numpy()
labels = tf.sparse.to_dense(
parsed_record["image/object/class/text"], default_value=b""
).numpy()
masks = tf.sparse.to_dense(
parsed_record["image/object/mask"], default_value=b""
).numpy()
for label, label_id in zip(labels, label_ids):
label = label.decode("utf-8")
if not label:
continue
if label_id <= 0:
continue
if label in dataset_labels:
continue
dataset_labels[label] = label_id - 1
item_id = osp.splitext(frame_filename)[0]
annotations = []
for shape_id, shape in enumerate(np.dstack((labels, xmins, ymins, xmaxs, ymaxs))[0]):
label = shape[0].decode("utf-8")
mask = None
if len(masks) != 0:
mask = masks[shape_id]
if mask is not None:
if isinstance(mask, bytes):
mask = lazy_image(mask, decode_image)
annotations.append(Mask(image=mask, label=dataset_labels.get(label)))
else:
x = clamp(shape[1] * frame_width, 0, frame_width)
y = clamp(shape[2] * frame_height, 0, frame_height)
w = clamp(shape[3] * frame_width, 0, frame_width) - x
h = clamp(shape[4] * frame_height, 0, frame_height) - y
annotations.append(Bbox(x, y, w, h, label=dataset_labels.get(label)))
image_size = None
if frame_height and frame_width:
image_size = (frame_height, frame_width)
image_params = {}
if frame_image:
image_params["data"] = frame_image
if frame_filename:
image_params["path"] = osp.join(images_dir, frame_filename)
image = None
if image_params:
image = ByteImage(**image_params, size=image_size)
dataset_items.append(
DatasetItem(
id=item_id,
subset=subset,
media=image,
annotations=annotations,
attributes={"source_id": frame_id},
)
)
return dataset_items, dataset_labels