# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import itertools
import logging as log
import os.path as osp
from types import SimpleNamespace as namespace
from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
import attrs
from attrs import field, frozen
from datumaro.components.annotation import AnnotationType, Bbox, Label, LabelCategories
from datumaro.components.extractor import CategoriesInfo, DatasetItem, IExtractor
from datumaro.components.media import ByteImage, Image, MediaElement
from datumaro.util.tf_util import import_tf
# TFDS support is optional: if TensorFlow or TFDS is missing, record the fact
# in TFDS_EXTRACTOR_AVAILABLE so callers can check it before using the
# extractor machinery defined below.
try:
    tf = import_tf()
    import tensorflow_datasets as tfds
except ImportError:
    log.debug(
        "Unable to import TensorFlow or TensorFlow Datasets. "
        "Dataset downloading via TFDS is disabled."
    )
    TFDS_EXTRACTOR_AVAILABLE = False
else:
    TFDS_EXTRACTOR_AVAILABLE = True
@frozen
class _TfdsAdapter:
    """Describes how one TFDS dataset is converted to the Datumaro model.

    The category and data transformers are applied in sequence; the optional
    ID generator derives Datumaro item IDs from raw TFDS examples.
    """

    category_transformers: Sequence[
        Callable[[tfds.core.DatasetBuilder, CategoriesInfo, namespace], None]
    ]
    data_transformers: Sequence[Callable[[Any, DatasetItem, namespace], None]]
    id_generator: Callable[[Any], str] = field(default=None, kw_only=True)
    metadata: TfdsDatasetMetadata

    def transform_categories(
        self,
        tfds_builder: tfds.core.DatasetBuilder,
        categories: CategoriesInfo,
        state: namespace,
    ) -> None:
        """Run every category transformer, in order, against *categories*."""
        for transformer in self.category_transformers:
            transformer(tfds_builder, categories, state)

    def transform_data(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        """Run every data transformer, in order, against *item*."""
        for transformer in self.data_transformers:
            transformer(tfds_example, item, state)
# A feature is addressed either by a plain name or by a path of segments.
_FeaturePath = Union[str, Tuple[str, ...]]


def _resolve_feature_path(
    feature_path: _FeaturePath,
    root: tfds.features.FeaturesDict,
) -> tfds.features.FeatureConnector:
    """Walk *root* along *feature_path* and return the connector found there.

    A plain string selects a direct child of *root*. A tuple descends one
    segment at a time; the literal segment "feature" steps into the element
    feature of a Sequence connector.
    """
    if isinstance(feature_path, str):
        return root[feature_path]

    current = root
    for segment in feature_path:
        assert isinstance(
            current,
            (tfds.features.FeaturesDict, tfds.features.Sequence),
        )
        if isinstance(current, tfds.features.Sequence):
            # A Sequence has a single child, so the path must say "feature".
            assert segment == "feature"
            current = current.feature
        else:
            current = current[segment]
    return current
@frozen
class _SetLabelCategoriesFromClassLabelFeature:
    """Category transformer: builds label categories from a ClassLabel feature."""

    feature_path: _FeaturePath

    def __call__(
        self,
        tfds_builder: tfds.core.DatasetBuilder,
        categories: CategoriesInfo,
        state: namespace,
    ) -> None:
        # Only one transformer is allowed to define the label categories.
        assert AnnotationType.label not in categories
        class_label = _resolve_feature_path(self.feature_path, tfds_builder.info.features)
        assert isinstance(class_label, tfds.features.ClassLabel)
        categories[AnnotationType.label] = LabelCategories.from_iterable(class_label.names)
@frozen
class _SetImageFromImageFeature:
    """Data transformer: sets the item's media from an undecoded image feature.

    If a filename feature is configured, its UTF-8 value becomes the image
    path; otherwise the image has no path.
    """

    feature_name: str
    filename_feature_name: Optional[str] = field(default=None)

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        filename = (
            tfds_example[self.filename_feature_name].numpy().decode("UTF-8")
            if self.filename_feature_name
            else None
        )
        item.media = ByteImage(data=tfds_example[self.feature_name].numpy(), path=filename)
@frozen
class _AddLabelFromClassLabelFeature:
    """Data transformer: appends a Label annotation from a ClassLabel feature."""

    feature_name: str

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        label_index = tfds_example[self.feature_name].numpy()
        item.annotations.append(Label(label_index))
@frozen
class _AttributeMemberMapping:
    """Maps a member of a TFDS object feature to a Datumaro attribute.

    Unless overridden, the attribute keeps the member's name. An optional
    value converter post-processes each raw value (it also receives the
    shared adapter state).
    """

    member_name: str
    # Defaults to member_name; attrs.Factory(takes_self=True) reads it from
    # the instance being built.
    attribute_name: str = field(
        default=attrs.Factory(lambda self: self.member_name, takes_self=True)
    )
    value_converter: Optional[Callable[[Any, namespace], Any]] = None
@frozen
class _AddObjectsFromFeature:
    """Data transformer: converts a TFDS object-detection feature into one
    Datumaro Bbox annotation per object, with an optional label and extra
    attributes copied from sibling members.
    """

    # Name of the per-example feature holding the object members.
    feature_name: str
    # Member of that feature containing the bounding boxes.
    bbox_member: str
    # Optional member containing per-object class indices.
    label_member: Optional[str] = field(default=None, kw_only=True)
    # Members to copy into each Bbox's attributes; plain strings are promoted
    # to _AttributeMemberMapping entries that reuse the member name.
    attribute_members: Tuple[_AttributeMemberMapping, ...] = field(
        default=(),
        kw_only=True,
        converter=lambda values: tuple(
            value if isinstance(value, _AttributeMemberMapping) else _AttributeMemberMapping(value)
            for value in values
        ),
    )

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        tfds_objects = tfds_example[self.feature_name]
        tfds_bboxes = tfds_objects[self.bbox_member]
        num_objects = tfds_bboxes.shape[0]
        tfds_labels = None
        if self.label_member is not None:
            tfds_labels = tfds_objects[self.label_member]
            # Every per-object member must be aligned with the boxes.
            assert tfds_labels.shape[0] == num_objects
        for am_mapping in self.attribute_members:
            assert tfds_objects[am_mapping.member_name].shape[0] == num_objects
        for i in range(num_objects):
            # TFDS boxes are (ymin, xmin, ymax, xmax) normalized to [0, 1];
            # scale by the media size — which the code below treats as
            # (height, width) — to obtain pixel coordinates.
            norm_ymin, norm_xmin, norm_ymax, norm_xmax = tfds_bboxes[i].numpy()
            new_bbox = Bbox(
                x=norm_xmin * item.media.size[1],
                y=norm_ymin * item.media.size[0],
                w=(norm_xmax - norm_xmin) * item.media.size[1],
                h=(norm_ymax - norm_ymin) * item.media.size[0],
            )
            if tfds_labels is not None:
                new_bbox.label = tfds_labels[i].numpy()
            for am_mapping in self.attribute_members:
                attr_value = tfds_objects[am_mapping.member_name][i].numpy()
                if am_mapping.value_converter:
                    attr_value = am_mapping.value_converter(attr_value, state)
                new_bbox.attributes[am_mapping.attribute_name] = attr_value
            item.annotations.append(new_bbox)
@frozen
class _SetAttributeFromFeature:
    """Data transformer: copies a scalar feature into an item-level attribute."""

    feature_name: str
    attribute_name: str

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        value = tfds_example[self.feature_name].numpy()
        item.attributes[self.attribute_name] = value
@frozen
class _GenerateIdFromTextFeature:
    """ID generator: uses a UTF-8 text feature verbatim as the item ID."""

    feature_name: str

    def __call__(self, tfds_example: Any) -> str:
        raw = tfds_example[self.feature_name].numpy()
        return raw.decode("UTF-8")
@frozen
class _GenerateIdFromFilenameFeature:
    """ID generator: uses a filename feature, minus its extension, as the ID."""

    feature_name: str

    def __call__(self, tfds_example: Any) -> str:
        file_name = tfds_example[self.feature_name].numpy().decode("UTF-8")
        stem, _extension = osp.splitext(file_name)
        return stem
# MNIST-style datasets: one image and one class label per example; items are
# identified by their ordinal (no ID generator).
_MNIST_ADAPTER = _TfdsAdapter(
    category_transformers=[_SetLabelCategoriesFromClassLabelFeature("label")],
    data_transformers=[
        _SetImageFromImageFeature("image"),
        _AddLabelFromClassLabelFeature("label"),
    ],
    metadata=TfdsDatasetMetadata(
        human_name="MNIST",
        default_output_format="mnist",
        media_type=Image,
    ),
)
# CIFAR-style datasets: like MNIST, but each example also carries a text "id"
# feature that is reused as the Datumaro item ID.
_CIFAR_ADAPTER = _TfdsAdapter(
    category_transformers=[_SetLabelCategoriesFromClassLabelFeature("label")],
    data_transformers=[
        _SetImageFromImageFeature("image"),
        _AddLabelFromClassLabelFeature("label"),
    ],
    id_generator=_GenerateIdFromTextFeature("id"),
    metadata=TfdsDatasetMetadata(
        human_name="CIFAR", default_output_format="cifar", media_type=Image
    ),
)
# COCO detection datasets: image plus labeled object boxes with an "is_crowd"
# attribute; item IDs come from the image file name without its extension.
_COCO_ADAPTER = _TfdsAdapter(
    category_transformers=[
        _SetLabelCategoriesFromClassLabelFeature(("objects", "feature", "label"))
    ],
    data_transformers=[
        _SetImageFromImageFeature("image", filename_feature_name="image/filename"),
        _AddObjectsFromFeature(
            "objects", "bbox", label_member="label", attribute_members=("is_crowd",)
        ),
        _SetAttributeFromFeature("image/id", "id"),
    ],
    id_generator=_GenerateIdFromFilenameFeature("image/filename"),
    metadata=TfdsDatasetMetadata(
        human_name="COCO", default_output_format="coco_instances", media_type=Image
    ),
)
# ImageNet classification datasets: image plus class label; item IDs come from
# the file name without its extension.
_IMAGENET_ADAPTER = _TfdsAdapter(
    category_transformers=[_SetLabelCategoriesFromClassLabelFeature("label")],
    data_transformers=[
        _SetImageFromImageFeature("image", filename_feature_name="file_name"),
        _AddLabelFromClassLabelFeature("label"),
    ],
    id_generator=_GenerateIdFromFilenameFeature("file_name"),
    metadata=TfdsDatasetMetadata(
        human_name="ImageNet", default_output_format="imagenet_txt", media_type=Image
    ),
)
def _voc_save_pose_names(
    tfds_builder: tfds.core.DatasetBuilder,
    categories: CategoriesInfo,
    state: namespace,
) -> None:
    """Category transformer: caches VOC pose names on the shared state.

    TFDS represents poses as indexes, but Datumaro represents them as strings,
    so the names are saved up front and consulted while converting boxes.
    TFDS also lower-cases the pose names even though the original dataset uses
    title case; restore title case so the output better resembles the
    original dataset.
    """
    raw_names = tfds_builder.info.features["objects"].feature["pose"].names
    state.pose_names = [raw_name.title() for raw_name in raw_names]
# PASCAL VOC detection datasets: labeled object boxes with difficult/truncated
# flags and a pose attribute translated from class indexes to title-case names
# (cached by _voc_save_pose_names).
_VOC_ADAPTER = _TfdsAdapter(
    category_transformers=[
        _SetLabelCategoriesFromClassLabelFeature(("objects", "feature", "label")),
        _voc_save_pose_names,
    ],
    data_transformers=[
        _SetImageFromImageFeature("image", filename_feature_name="image/filename"),
        _AddObjectsFromFeature(
            "objects",
            "bbox",
            label_member="label",
            attribute_members=(
                _AttributeMemberMapping("is_difficult", "difficult"),
                _AttributeMemberMapping("is_truncated", "truncated"),
                _AttributeMemberMapping(
                    "pose", value_converter=lambda idx, state: state.pose_names[idx]
                ),
            ),
        ),
    ],
    id_generator=_GenerateIdFromFilenameFeature("image/filename"),
    metadata=TfdsDatasetMetadata(
        human_name="PASCAL VOC", default_output_format="voc", media_type=Image
    ),
)
def _evolve_adapter_meta(adapter: _TfdsAdapter, **kwargs):
    """Return a copy of *adapter* whose metadata fields are updated by *kwargs*."""
    new_metadata = attrs.evolve(adapter.metadata, **kwargs)
    return attrs.evolve(adapter, metadata=new_metadata)
# Registry of supported TFDS dataset names; the base adapters are specialized
# with dataset-specific human-readable names.
_TFDS_ADAPTERS = {
    "cifar10": _evolve_adapter_meta(_CIFAR_ADAPTER, human_name="CIFAR-10"),
    "cifar100": _evolve_adapter_meta(_CIFAR_ADAPTER, human_name="CIFAR-100"),
    "coco/2014": _evolve_adapter_meta(_COCO_ADAPTER, human_name="COCO (2014-2015)"),
    "imagenet_v2": _evolve_adapter_meta(_IMAGENET_ADAPTER, human_name="ImageNetV2"),
    "mnist": _MNIST_ADAPTER,
    "voc/2012": _evolve_adapter_meta(_VOC_ADAPTER, human_name="PASCAL VOC 2012"),
}
# Assign the TFDS catalog page as the documentation URL for all datasets.
_TFDS_ADAPTERS = {
    name: _evolve_adapter_meta(
        adapter,
        home_url="https://www.tensorflow.org/datasets/catalog/" + name.split("/", maxsplit=1)[0],
    )
    for name, adapter in _TFDS_ADAPTERS.items()
}
class _TfdsSplitExtractor(IExtractor):
    """Presents a single split of a TFDS dataset as a Datumaro extractor.

    Categories, shared state and the conversion adapter are owned by the
    parent _TfdsExtractor; this class only iterates its own split.
    """

    def __init__(
        self,
        parent: _TfdsExtractor,
        tfds_split: tf.data.Dataset,
        tfds_split_info: tfds.core.SplitInfo,
    ):
        self._parent = parent
        self._tfds_split = tfds_split
        self._tfds_split_info = tfds_split_info

    def __len__(self) -> int:
        return self._tfds_split_info.num_examples

    def __iter__(self) -> Iterator[DatasetItem]:
        adapter = self._parent._adapter
        for example_index, tfds_example in enumerate(self._tfds_split):
            # Fall back to the example's ordinal when the dataset has no
            # natural ID feature.
            if adapter.id_generator:
                item_id = adapter.id_generator(tfds_example)
            else:
                item_id = str(example_index)

            dm_item = DatasetItem(id=item_id, subset=self._tfds_split_info.name)
            adapter.transform_data(tfds_example, dm_item, self._parent._state)
            yield dm_item

    def categories(self) -> CategoriesInfo:
        return self._parent.categories()

    def subsets(self) -> Dict[str, IExtractor]:
        return {self._tfds_split_info.name: self}

    def get_subset(self, name) -> IExtractor:
        assert name == self._tfds_split_info.name
        return self

    def get(self, id, subset=None) -> Optional[DatasetItem]:
        if subset is not None and subset != self._tfds_split_info.name:
            return None
        # TFDS examples are not indexed by ID, so this is a linear scan.
        return next((item for item in self if item.id == id), None)

    def media_type(self) -> Type[MediaElement]:
        return self._parent._media_type
class _TfdsExtractor(IExtractor):
    """Downloads a dataset through TFDS and adapts it to the Datumaro model.

    Each TFDS split becomes a subset backed by a _TfdsSplitExtractor.
    """

    _categories: CategoriesInfo

    def __init__(self, tfds_ds_name: str) -> None:
        self._adapter = _TFDS_ADAPTERS[tfds_ds_name]
        tfds_builder = tfds.builder(tfds_ds_name)
        tfds_ds_info = tfds_builder.info

        self._categories = {}
        self._state = namespace()
        self._adapter.transform_categories(tfds_builder, self._categories, self._state)
        self._media_type = self._adapter.metadata.media_type

        # Skip decoding of image features: the raw encoded bytes are kept and
        # handed to Datumaro as-is.
        tfds_decoders = {
            feature_name: tfds.decode.SkipDecoding()
            for feature_name, feature_connector in tfds_ds_info.features.items()
            if isinstance(feature_connector, tfds.features.Image)
        }

        tfds_builder.download_and_prepare()
        self._tfds_ds = tfds_builder.as_dataset(decoders=tfds_decoders)

        # Since dicts in Python 3.7+ (and de facto in 3.6+) are
        # order-preserving, sort the splits by name so that we always
        # iterate over them in alphabetical order.
        self._split_extractors = {
            split_name: _TfdsSplitExtractor(self, split, tfds_ds_info.splits[split_name])
            for split_name, split in sorted(self._tfds_ds.items())
        }

    def __len__(self) -> int:
        return sum(map(len, self._split_extractors.values()))

    def __iter__(self) -> Iterator[DatasetItem]:
        return itertools.chain.from_iterable(self._split_extractors.values())

    def categories(self) -> CategoriesInfo:
        return self._categories

    def subsets(self) -> Dict[str, IExtractor]:
        return self._split_extractors

    def get_subset(self, name) -> IExtractor:
        return self._split_extractors[name]

    def get(self, id, subset=None) -> Optional[DatasetItem]:
        if subset is None:
            # No subset given: search every split; first match wins.
            for split_extractor in self._split_extractors.values():
                found = split_extractor.get(id)
                if found is not None:
                    return found
            return None

        split_extractor = self._split_extractors.get(subset)
        return None if split_extractor is None else split_extractor.get(id)

    def media_type(self) -> Type[MediaElement]:
        return self._media_type
# Some dataset metadata elements are either inconvenient to hardcode, or may change
# depending on the version of TFDS. We fetch them from the attributes of the `tfds.Builder`
# object. However, creating the builder may be time-consuming, because if the dataset
# is not already downloaded, TFDS fetches some data from its Google Cloud bucket.
# We therefore only fetch this metadata (which we call _remote_ metadata) when
# we actually need it.
# If you add attributes to either of these metadata classes, make sure to update
# the reporting logic in the `describe-downloads` command to include them.
# NOTE: the stray "[docs]" markers on this class were Sphinx viewcode-link
# artifacts from a documentation scrape; they made the file a syntax error
# and have been removed.
class TfdsDataset:
    """Handle to a dataset that can be downloaded via TFDS.

    Wraps the adapter registered for *tfds_ds_name* and exposes the metadata
    that is known without contacting TFDS.
    """

    def __init__(self, tfds_ds_name: str):
        self._tfds_ds_name = tfds_ds_name
        self._adapter = _TFDS_ADAPTERS[tfds_ds_name]

    @property
    def metadata(self) -> TfdsDatasetMetadata:
        """The hardcoded (local) metadata of the dataset's adapter."""
        return self._adapter.metadata
# Public registry: one TfdsDataset handle per supported TFDS dataset name.
AVAILABLE_TFDS_DATASETS: Mapping[str, TfdsDataset] = {
    name: TfdsDataset(name) for name in _TFDS_ADAPTERS
}