# Source code for datumaro.components.extractor_tfds

# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import itertools
import logging as log
import os.path as osp
from types import SimpleNamespace as namespace
from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union

import attrs
from attrs import field, frozen

from datumaro.components.annotation import AnnotationType, Bbox, Label, LabelCategories
from datumaro.components.extractor import CategoriesInfo, DatasetItem, IExtractor
from datumaro.components.media import ByteImage, Image, MediaElement
from datumaro.util.tf_util import import_tf

# TFDS support is optional: if TensorFlow / TFDS cannot be imported, the module
# still loads, but downloading via TFDS is reported as unavailable through
# the TFDS_EXTRACTOR_AVAILABLE flag.
try:
    tf = import_tf()
    import tensorflow_datasets as tfds
except ImportError:
    log.debug(
        "Unable to import TensorFlow or TensorFlow Datasets. "
        "Dataset downloading via TFDS is disabled."
    )
    TFDS_EXTRACTOR_AVAILABLE = False
else:
    TFDS_EXTRACTOR_AVAILABLE = True


@frozen(kw_only=True)
class TfdsDatasetMetadata:
    """Locally-known, static metadata for a dataset downloadable via TFDS."""

    # If you add attributes to this class, make sure to update the reporting logic
    # in the `describe-downloads` command to include them.

    # Name to show to users (e.g. in `describe-downloads` output).
    human_name: str
    # Datumaro format name used by default when exporting this dataset.
    default_output_format: str
    # Type of media (e.g. Image) that the dataset's items carry.
    media_type: Type[MediaElement]

    # Every TFDS dataset has a home page (the TFDS documentation page), but
    # this field is still optional for a couple of reasons:
    # * We might want to reuse this class later if we implement downloading
    #   without the use of TFDS, and at that point it might happen that a dataset
    #   has no home page (e.g. it's been taken down).
    # * It's convenient to initially leave this field blank and mass-update it later,
    #   and in order for us to not break typing, it should be optional.
    home_url: Optional[str] = None
@frozen
class _TfdsAdapter:
    """Describes how to convert one TFDS dataset into Datumaro's data model.

    An adapter is a bundle of small transformer callables: category
    transformers build the CategoriesInfo from the TFDS builder, and data
    transformers fill in a DatasetItem from a single TFDS example. A shared
    `namespace` object is threaded through both phases so category
    transformers can stash state (e.g. pose names) for data transformers.
    """

    # Callables invoked once per dataset to populate CategoriesInfo.
    category_transformers: Sequence[
        Callable[[tfds.core.DatasetBuilder, CategoriesInfo, namespace], None]
    ]
    # Callables invoked once per example to populate a DatasetItem.
    data_transformers: Sequence[Callable[[Any, DatasetItem, namespace], None]]
    # Optional callable producing an item ID from a TFDS example; when None,
    # the example's index within its split is used instead.
    id_generator: Optional[Callable[[Any], str]] = field(default=None, kw_only=True)

    metadata: TfdsDatasetMetadata

    def transform_categories(
        self,
        tfds_builder: tfds.core.DatasetBuilder,
        categories: CategoriesInfo,
        state: namespace,
    ) -> None:
        """Applies all category transformers in order, mutating `categories`."""
        for t in self.category_transformers:
            t(tfds_builder, categories, state)

    def transform_data(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        """Applies all data transformers in order, mutating `item`."""
        for t in self.data_transformers:
            t(tfds_example, item, state)


# A path to a feature: either a single name or a tuple of path segments
# (with "feature" used to step into a Sequence feature).
_FeaturePath = Union[str, Tuple[str, ...]]


def _resolve_feature_path(
    feature_path: _FeaturePath,
    root: tfds.features.FeaturesDict,
) -> tfds.features.FeatureConnector:
    """Walks `feature_path` down from `root` and returns the feature it names.

    A string path is a direct lookup; a tuple path is resolved segment by
    segment, where the special segment "feature" steps into a Sequence.
    """
    if isinstance(feature_path, str):
        return root[feature_path]

    feature_connector = root

    for segment in feature_path:
        assert isinstance(
            feature_connector,
            (
                tfds.features.FeaturesDict,
                tfds.features.Sequence,
            ),
        )

        if isinstance(feature_connector, tfds.features.Sequence):
            assert segment == "feature"
            feature_connector = feature_connector.feature
        else:
            feature_connector = feature_connector[segment]

    return feature_connector


@frozen
class _SetLabelCategoriesFromClassLabelFeature:
    """Category transformer: builds label categories from a ClassLabel feature."""

    feature_path: _FeaturePath

    def __call__(
        self,
        tfds_builder: tfds.core.DatasetBuilder,
        categories: CategoriesInfo,
        state: namespace,
    ) -> None:
        # Only one transformer may define the label categories.
        assert AnnotationType.label not in categories

        feature_connector = _resolve_feature_path(self.feature_path, tfds_builder.info.features)

        assert isinstance(feature_connector, tfds.features.ClassLabel)
        categories[AnnotationType.label] = LabelCategories.from_iterable(feature_connector.names)


@frozen
class _SetImageFromImageFeature:
    """Data transformer: sets the item's media from raw (undecoded) image bytes."""

    feature_name: str
    # Optional feature holding the original file name, recorded as the image path.
    filename_feature_name: Optional[str] = field(default=None)

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        if self.filename_feature_name:
            filename = tfds_example[self.filename_feature_name].numpy().decode("UTF-8")
        else:
            filename = None

        # The image feature is read with SkipDecoding (see _TfdsExtractor),
        # so .numpy() yields the encoded bytes, suitable for ByteImage.
        item.media = ByteImage(data=tfds_example[self.feature_name].numpy(), path=filename)


@frozen
class _AddLabelFromClassLabelFeature:
    """Data transformer: adds a Label annotation from a ClassLabel feature."""

    feature_name: str

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        item.annotations.append(
            Label(tfds_example[self.feature_name].numpy()),
        )


@frozen
class _AttributeMemberMapping:
    """Maps a member of a TFDS object feature to a Datumaro attribute name."""

    member_name: str
    # Defaults to member_name when not given (see the default factory below).
    attribute_name: str = field()
    # Optional converter applied to the raw member value before recording it.
    value_converter: Optional[Callable[[Any, namespace], Any]] = None

    @attribute_name.default
    def _attribute_name_default(self):
        return self.member_name


@frozen
class _AddObjectsFromFeature:
    """Data transformer: converts a Sequence-of-objects feature into Bbox annotations."""

    feature_name: str
    bbox_member: str
    label_member: Optional[str] = field(default=None, kw_only=True)
    # Accepts plain strings for convenience; they are wrapped into
    # _AttributeMemberMapping by the converter.
    attribute_members: Tuple[_AttributeMemberMapping, ...] = field(
        default=(),
        kw_only=True,
        converter=lambda values: tuple(
            value if isinstance(value, _AttributeMemberMapping) else _AttributeMemberMapping(value)
            for value in values
        ),
    )

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        tfds_objects = tfds_example[self.feature_name]
        tfds_bboxes = tfds_objects[self.bbox_member]
        num_objects = tfds_bboxes.shape[0]

        tfds_labels = None
        if self.label_member is not None:
            tfds_labels = tfds_objects[self.label_member]
            assert tfds_labels.shape[0] == num_objects

        for am_mapping in self.attribute_members:
            assert tfds_objects[am_mapping.member_name].shape[0] == num_objects

        for i in range(num_objects):
            # TFDS boxes are normalized (ymin, xmin, ymax, xmax); convert to
            # absolute XYWH using the media size (size is (height, width)).
            norm_ymin, norm_xmin, norm_ymax, norm_xmax = tfds_bboxes[i].numpy()

            new_bbox = Bbox(
                x=norm_xmin * item.media.size[1],
                y=norm_ymin * item.media.size[0],
                w=(norm_xmax - norm_xmin) * item.media.size[1],
                h=(norm_ymax - norm_ymin) * item.media.size[0],
            )

            if tfds_labels is not None:
                new_bbox.label = tfds_labels[i].numpy()

            for am_mapping in self.attribute_members:
                attr_value = tfds_objects[am_mapping.member_name][i].numpy()
                if am_mapping.value_converter:
                    attr_value = am_mapping.value_converter(attr_value, state)
                new_bbox.attributes[am_mapping.attribute_name] = attr_value

            item.annotations.append(new_bbox)


@frozen
class _SetAttributeFromFeature:
    """Data transformer: copies a feature's value into an item attribute."""

    feature_name: str
    attribute_name: str

    def __call__(
        self,
        tfds_example: Any,
        item: DatasetItem,
        state: namespace,
    ) -> None:
        item.attributes[self.attribute_name] = tfds_example[self.feature_name].numpy()


@frozen
class _GenerateIdFromTextFeature:
    """ID generator: uses a text feature's value as the item ID."""

    feature_name: str

    def __call__(self, tfds_example: Any) -> str:
        return tfds_example[self.feature_name].numpy().decode("UTF-8")


@frozen
class _GenerateIdFromFilenameFeature:
    """ID generator: uses a file-name feature, with the extension stripped."""

    feature_name: str

    def __call__(self, tfds_example: Any) -> str:
        file_name = tfds_example[self.feature_name].numpy().decode("UTF-8")
        return osp.splitext(file_name)[0]


_MNIST_ADAPTER = _TfdsAdapter(
    category_transformers=[_SetLabelCategoriesFromClassLabelFeature("label")],
    data_transformers=[
        _SetImageFromImageFeature("image"),
        _AddLabelFromClassLabelFeature("label"),
    ],
    metadata=TfdsDatasetMetadata(
        human_name="MNIST",
        default_output_format="mnist",
        media_type=Image,
    ),
)

_CIFAR_ADAPTER = _TfdsAdapter(
    category_transformers=[_SetLabelCategoriesFromClassLabelFeature("label")],
    data_transformers=[
        _SetImageFromImageFeature("image"),
        _AddLabelFromClassLabelFeature("label"),
    ],
    id_generator=_GenerateIdFromTextFeature("id"),
    metadata=TfdsDatasetMetadata(
        human_name="CIFAR", default_output_format="cifar", media_type=Image
    ),
)

_COCO_ADAPTER = _TfdsAdapter(
    category_transformers=[
        _SetLabelCategoriesFromClassLabelFeature(("objects", "feature", "label"))
    ],
    data_transformers=[
        _SetImageFromImageFeature("image", filename_feature_name="image/filename"),
        _AddObjectsFromFeature(
            "objects", "bbox", label_member="label", attribute_members=("is_crowd",)
        ),
        _SetAttributeFromFeature("image/id", "id"),
    ],
    id_generator=_GenerateIdFromFilenameFeature("image/filename"),
    metadata=TfdsDatasetMetadata(
        human_name="COCO", default_output_format="coco_instances", media_type=Image
    ),
)

_IMAGENET_ADAPTER = _TfdsAdapter(
    category_transformers=[_SetLabelCategoriesFromClassLabelFeature("label")],
    data_transformers=[
        _SetImageFromImageFeature("image", filename_feature_name="file_name"),
        _AddLabelFromClassLabelFeature("label"),
    ],
    id_generator=_GenerateIdFromFilenameFeature("file_name"),
    metadata=TfdsDatasetMetadata(
        human_name="ImageNet", default_output_format="imagenet_txt", media_type=Image
    ),
)


def _voc_save_pose_names(
    tfds_builder: tfds.core.DatasetBuilder,
    categories: CategoriesInfo,
    state: namespace,
) -> None:
    """Category transformer: caches VOC pose names on `state` for later use."""
    # TFDS represents poses as indexes, but Datumaro represents them as strings.
    # To convert between representations, save the pose names at the start and
    # use them when we're converting boxes.
    # TFDS also provides the pose names in lower case, even though they're title
    # case in the original dataset. Fix them back to title case so that the
    # output dataset better resembles the original dataset.
    state.pose_names = [
        name.title() for name in tfds_builder.info.features["objects"].feature["pose"].names
    ]


_VOC_ADAPTER = _TfdsAdapter(
    category_transformers=[
        _SetLabelCategoriesFromClassLabelFeature(("objects", "feature", "label")),
        _voc_save_pose_names,
    ],
    data_transformers=[
        _SetImageFromImageFeature("image", filename_feature_name="image/filename"),
        _AddObjectsFromFeature(
            "objects",
            "bbox",
            label_member="label",
            attribute_members=(
                _AttributeMemberMapping("is_difficult", "difficult"),
                _AttributeMemberMapping("is_truncated", "truncated"),
                _AttributeMemberMapping(
                    "pose", value_converter=lambda idx, state: state.pose_names[idx]
                ),
            ),
        ),
    ],
    id_generator=_GenerateIdFromFilenameFeature("image/filename"),
    metadata=TfdsDatasetMetadata(
        human_name="PASCAL VOC", default_output_format="voc", media_type=Image
    ),
)


def _evolve_adapter_meta(adapter: _TfdsAdapter, **kwargs):
    """Returns a copy of `adapter` with its metadata fields replaced by `kwargs`."""
    return attrs.evolve(adapter, metadata=attrs.evolve(adapter.metadata, **kwargs))


_TFDS_ADAPTERS = {
    "cifar10": _evolve_adapter_meta(_CIFAR_ADAPTER, human_name="CIFAR-10"),
    "cifar100": _evolve_adapter_meta(_CIFAR_ADAPTER, human_name="CIFAR-100"),
    "coco/2014": _evolve_adapter_meta(_COCO_ADAPTER, human_name="COCO (2014-2015)"),
    "imagenet_v2": _evolve_adapter_meta(_IMAGENET_ADAPTER, human_name="ImageNetV2"),
    "mnist": _MNIST_ADAPTER,
    "voc/2012": _evolve_adapter_meta(_VOC_ADAPTER, human_name="PASCAL VOC 2012"),
}

# Assign the TFDS catalog page as the documentation URL for all datasets.
_TFDS_ADAPTERS = {
    name: _evolve_adapter_meta(
        adapter,
        home_url="https://www.tensorflow.org/datasets/catalog/"
        + name.split("/", maxsplit=1)[0],
    )
    for name, adapter in _TFDS_ADAPTERS.items()
}


class _TfdsSplitExtractor(IExtractor):
    """Extractor over a single TFDS split; items are produced lazily on iteration."""

    def __init__(
        self,
        parent: _TfdsExtractor,
        tfds_split: tf.data.Dataset,
        tfds_split_info: tfds.core.SplitInfo,
    ):
        self._parent = parent
        self._tfds_split = tfds_split
        self._tfds_split_info = tfds_split_info

    def __len__(self) -> int:
        return self._tfds_split_info.num_examples

    def __iter__(self) -> Iterator[DatasetItem]:
        for example_index, tfds_example in enumerate(self._tfds_split):
            # Without an adapter-provided ID generator, fall back to the
            # example's position within the split.
            if self._parent._adapter.id_generator:
                item_id = self._parent._adapter.id_generator(tfds_example)
            else:
                item_id = str(example_index)

            dm_item = DatasetItem(id=item_id, subset=self._tfds_split_info.name)
            self._parent._adapter.transform_data(tfds_example, dm_item, self._parent._state)

            yield dm_item

    def categories(self) -> CategoriesInfo:
        return self._parent.categories()

    def subsets(self) -> Dict[str, IExtractor]:
        return {self._tfds_split_info.name: self}

    def get_subset(self, name) -> IExtractor:
        assert name == self._tfds_split_info.name
        return self

    def get(self, id, subset=None) -> Optional[DatasetItem]:
        # NOTE: linear scan over the split; items have no index by ID.
        if subset is not None and subset != self._tfds_split_info.name:
            return None

        for item in self:
            if item.id == id:
                return item

        return None

    def media_type(self) -> Type[MediaElement]:
        return self._parent._media_type


class _TfdsExtractor(IExtractor):
    """Extractor over a whole TFDS dataset; downloads and prepares it on construction."""

    _categories: CategoriesInfo

    def __init__(self, tfds_ds_name: str) -> None:
        self._adapter = _TFDS_ADAPTERS[tfds_ds_name]
        tfds_builder = tfds.builder(tfds_ds_name)
        tfds_ds_info = tfds_builder.info

        self._categories = {}
        self._state = namespace()
        self._adapter.transform_categories(tfds_builder, self._categories, self._state)

        self._media_type = self._adapter.metadata.media_type

        # Skip decoding of image features so that transformers receive the
        # original encoded bytes rather than decoded arrays.
        tfds_decoders = {}
        for tfds_feature_name, tfds_fc in tfds_ds_info.features.items():
            if isinstance(tfds_fc, tfds.features.Image):
                tfds_decoders[tfds_feature_name] = tfds.decode.SkipDecoding()

        tfds_builder.download_and_prepare()
        self._tfds_ds = tfds_builder.as_dataset(decoders=tfds_decoders)

        self._split_extractors = {
            split_name: _TfdsSplitExtractor(self, split, tfds_ds_info.splits[split_name])
            # Since dicts in Python 3.7+ (and de facto in 3.6+) are
            # order-preserving, sort the splits by name so that we always
            # iterate over them in alphabetical order.
            for split_name, split in sorted(self._tfds_ds.items())
        }

    def __len__(self) -> int:
        return sum(len(ex) for ex in self._split_extractors.values())

    def __iter__(self) -> Iterator[DatasetItem]:
        return itertools.chain.from_iterable(self._split_extractors.values())

    def categories(self) -> CategoriesInfo:
        return self._categories

    def subsets(self) -> Dict[str, IExtractor]:
        return self._split_extractors

    def get_subset(self, name) -> IExtractor:
        return self._split_extractors[name]

    def get(self, id, subset=None) -> Optional[DatasetItem]:
        # With no subset given, search every split in order.
        if subset is None:
            for ex in self._split_extractors.values():
                item = ex.get(id)
                if item is not None:
                    return item
            return None

        if subset not in self._split_extractors:
            return None
        return self._split_extractors[subset].get(id)

    def media_type(self) -> Type[MediaElement]:
        return self._media_type


# Some dataset metadata elements are either inconvenient to hardcode, or may change
# depending on the version of TFDS. We fetch them from the attributes of the `tfds.Builder`
# object. However, creating the builder may be time-consuming, because if the dataset
# is not already downloaded, TFDS fetches some data from its Google Cloud bucket.
# We therefore only fetch this metadata (which we call _remote_ metadata) when
# we actually need it.

# If you add attributes to either of these metadata classes, make sure to update
# the reporting logic in the `describe-downloads` command to include them.
@frozen(kw_only=True)
class TfdsSubsetRemoteMetadata:
    """Remote metadata describing a single subset (split) of a TFDS dataset."""

    # Number of examples in the subset.
    num_items: int
@frozen(kw_only=True)
class TfdsDatasetRemoteMetadata(TfdsDatasetMetadata):
    """Metadata fetched from the TFDS builder, which may require network access."""

    # For convenience, the remote metadata also includes the local metadata.
    description: str
    download_size: int
    subsets: Mapping[str, TfdsSubsetRemoteMetadata]

    # the dataset might not have the notion of classes, so `num_classes` may
    # be None.
    num_classes: Optional[int]
    version: str
class TfdsDataset:
    """A dataset that can be downloaded via TFDS and read as a Datumaro extractor."""

    def __init__(self, tfds_ds_name: str):
        # The adapter lookup also fails fast on unsupported dataset names.
        self._tfds_ds_name = tfds_ds_name
        self._adapter = _TFDS_ADAPTERS[tfds_ds_name]

    @property
    def metadata(self) -> TfdsDatasetMetadata:
        """Locally-known metadata; available without any network access."""
        return self._adapter.metadata

    def make_extractor(self) -> IExtractor:
        """Builds an extractor that prepares and reads the underlying TFDS dataset."""
        return _TfdsExtractor(self._tfds_ds_name)

    def query_remote_metadata(self) -> TfdsDatasetRemoteMetadata:
        """Fetches extended metadata from the TFDS builder (may hit the network)."""
        builder = tfds.builder(self._tfds_ds_name)
        info = builder.info

        # Run the category transformers to learn the class count, if any.
        label_categories = {}
        self._adapter.transform_categories(builder, label_categories, namespace())

        labels = label_categories.get(AnnotationType.label)
        num_classes = len(labels) if labels is not None else None

        subset_meta = {}
        for split_name, split_info in info.splits.items():
            subset_meta[split_name] = TfdsSubsetRemoteMetadata(num_items=split_info.num_examples)

        return TfdsDatasetRemoteMetadata(
            **attrs.asdict(self._adapter.metadata),
            description=info.description,
            download_size=int(info.download_size),
            num_classes=num_classes,
            subsets=subset_meta,
            version=str(info.version),
        )
# Public registry of all datasets supported by this module, keyed by TFDS name.
AVAILABLE_TFDS_DATASETS: Mapping[str, TfdsDataset] = {
    name: TfdsDataset(name) for name in _TFDS_ADAPTERS
}