# Copyright (C) 2019-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
from glob import iglob
from typing import Any, Callable, Dict, Iterator, List, Optional
import os
import os.path as osp
from attr import attrib, attrs
import attr
import numpy as np
from datumaro.components.annotation import (
Annotation, AnnotationType, Categories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.errors import DatasetNotFoundError
from datumaro.components.format_detection import (
FormatDetectionConfidence, FormatDetectionContext,
)
from datumaro.components.media import Image
from datumaro.util import is_method_redefined
from datumaro.util.attrs_util import default_if_none, not_empty
# Re-export some names from .annotation for backwards compatibility.
import datumaro.components.annotation # isort:skip
for _name in [
    'Annotation', 'AnnotationType', 'Bbox', 'Caption', 'Categories',
    'CompiledMask', 'Cuboid3d', 'Label', 'LabelCategories', 'Mask',
    'MaskCategories', 'Points', 'PointsCategories', 'Polygon', 'RleMask',
]:
    # Bind the object under the same name in this module's namespace, so
    # that it can still be looked up as an attribute of this module.
    globals()[_name] = getattr(datumaro.components.annotation, _name)

# Subset name substituted when a subset name is empty or not given.
DEFAULT_SUBSET_NAME = 'default'
@attrs(order=False)
class DatasetItem:
    """
    A single dataset entry: an identifier, a subset name, annotations,
    free-form attributes and optional media (either an image, or a point
    cloud with optional related images).

    The order of the field declarations below is significant, because the
    attrs-generated ``__init__`` takes its parameters in declaration order.

    Raises:
        ValueError: if both an image and a point cloud are set, or if
            related images are given without a point cloud.
    """

    # Item identifier. Any value is converted to a string, and backslashes
    # are replaced with forward slashes. Checked by the not_empty validator.
    id: str = attrib(converter=lambda x: str(x).replace('\\', '/'),
        validator=not_empty)

    # Annotations attached to the item; a new list is created by default.
    annotations: List[Annotation] = attrib(
        factory=list, validator=default_if_none(list))

    # Subset name. Falsy values (None, '') become DEFAULT_SUBSET_NAME.
    subset: str = attrib(converter=lambda v: v or DEFAULT_SUBSET_NAME,
        default=None)

    # TODO: introduce "media" field with type info. Replace image and pcd.
    image: Optional[Image] = attrib(default=None)

    # TODO: introduce pcd type like Image
    # Point cloud path. Truthy values are converted to a string with
    # backslashes replaced by forward slashes; falsy values become None.
    point_cloud: Optional[str] = attrib(
        converter=lambda x: str(x).replace('\\', '/') if x else None,
        default=None)

    # Images accompanying the point cloud. The converter assigned below
    # turns None into an empty list.
    related_images: List[Image] = attrib(default=None)

    def __attrs_post_init__(self):
        """Check media consistency after the generated __init__ has run."""
        if (self.has_image and self.has_point_cloud):
            raise ValueError("Can't set both image and point cloud info")
        if self.related_images and not self.has_point_cloud:
            raise ValueError("Related images require point cloud")

    def _image_converter(image):
        """
        Normalize an image argument to an Image instance or None.

        Callables and numpy arrays are wrapped as Image(data=...), strings
        as Image(path=...). None and Image instances pass through as is.
        """
        if callable(image) or isinstance(image, np.ndarray):
            image = Image(data=image)
        elif isinstance(image, str):
            image = Image(path=image)
        assert image is None or isinstance(image, Image), type(image)
        return image
    # Attached after the definition, because the converter function does not
    # exist yet at the point where the field is declared.
    image.converter = _image_converter

    def _related_image_converter(images):
        """Convert each entry with _image_converter; None gives []."""
        # __class__ is only resolved when the converter is called, by which
        # time the class object exists.
        return list(map(__class__._image_converter, images or []))
    related_images.converter = _related_image_converter

    @point_cloud.validator
    def _point_cloud_validator(self, attribute, pcd):
        """Check that the (already converted) point cloud is None or str."""
        assert pcd is None or isinstance(pcd, str), type(pcd)

    # Free-form item attributes; a new dict is created by default.
    # Declared here, so it comes after the media fields in __init__.
    attributes: Dict[str, Any] = attrib(
        factory=dict, validator=default_if_none(dict))

    @property
    def has_image(self):
        """True if the image field is not None."""
        return self.image is not None

    @property
    def has_point_cloud(self):
        """True if the point_cloud field is not None."""
        return self.point_cloud is not None

    def wrap(item, **kwargs):
        """
        Return a copy of the item with the given fields replaced.

        Parameters:
            **kwargs: field values to change, forwarded to attr.evolve().

        Returns: the new item produced by attr.evolve()
        """
        return attr.evolve(item, **kwargs)
# Type alias: a mapping from an annotation type to a Categories object.
CategoriesInfo = Dict[AnnotationType, Categories]
class ExtractorBase(IExtractor):
    """
    Extractor base with lazily computed length and subset names.

    Both values can be passed to the constructor. Any value left as None
    is computed on first use by iterating over the extractor once.
    """

    def __init__(self, length=None, subsets=None):
        # None means "not known yet", see _init_cache().
        self._subsets = subsets
        self._length = length

    def _init_cache(self):
        """
        Iterate over all the items once, counting them and collecting
        their subset names. Only values that are still None are filled in.
        """
        seen_subsets = set()
        item_count = 0
        for item_count, item in enumerate(self, start=1):
            seen_subsets.add(item.subset)

        if self._length is None:
            self._length = item_count
        if self._subsets is None:
            self._subsets = seen_subsets

    def __len__(self):
        """Return the item count, computing it first if it is not cached."""
        cached = self._length
        if cached is None:
            self._init_cache()
            cached = self._length
        return cached

    def subsets(self) -> Dict[str, IExtractor]:
        """
        Return a mapping from subset name to the extractor of that subset.
        A falsy subset name is listed under DEFAULT_SUBSET_NAME.
        """
        if self._subsets is None:
            self._init_cache()

        by_name = {}
        for name in self._subsets:
            by_name[name or DEFAULT_SUBSET_NAME] = self.get_subset(name)
        return by_name

    def get_subset(self, name):
        """
        Return an extractor restricted to the items of the subset `name`.

        If `name` is the only known subset, the extractor itself is
        returned. Raises KeyError if the subset name is not known.
        """
        if self._subsets is None:
            self._init_cache()

        if name not in self._subsets:
            raise KeyError("Unknown subset '%s', available subsets: %s" % \
                (name, set(self._subsets)))

        if len(self._subsets) == 1:
            return self

        filtered = self.select(lambda item: item.subset == name)
        # The result holds a single subset, so there is no need to scan it.
        filtered._subsets = [name]
        return filtered

    def transform(self, method, *args, **kwargs):
        """Call `method` with this extractor as the first argument."""
        return method(self, *args, **kwargs)

    def select(self, pred):
        """
        Return an extractor that yields only the items for which `pred`
        is true, and reports the categories of this extractor.
        """
        source = self

        class _DatasetFilter(ExtractorBase):
            def __iter__(self):
                return filter(pred, iter(source))

            def categories(self):
                return source.categories()

        return _DatasetFilter()

    def categories(self):
        """Return the categories; an empty dict in this base class."""
        return dict()

    def get(self, id, subset=None):
        """
        Return the first item with the given id and subset, or None.
        A falsy `subset` means DEFAULT_SUBSET_NAME.
        """
        wanted_subset = subset or DEFAULT_SUBSET_NAME
        matches = (item for item in self
            if item.id == id and item.subset == wanted_subset)
        return next(matches, None)
class Importer(CliPlugin):
    """
    Base class for dataset importers.

    An importer locates dataset sources at a path and describes them as
    a list of source configurations. Subclasses must implement
    find_sources().
    """

    @classmethod
    def detect(
        cls, context: FormatDetectionContext,
    ) -> Optional[FormatDetectionConfidence]:
        """
        Default format detection: report low confidence if any sources
        are found at the root path of the detection context.
        """
        if not cls.find_sources_with_params(context.root_path):
            # NOTE(review): context.fail() presumably raises to abort the
            # detection - otherwise LOW is returned here too. Confirm.
            context.fail("specific requirement information unavailable")

        return FormatDetectionConfidence.LOW

    @classmethod
    def find_sources(cls, path) -> List[Dict]:
        """
        Find dataset sources at the path. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def find_sources_with_params(cls, path, **extra_params) -> List[Dict]:
        """
        Find dataset sources at the path, given the extra parameters.

        This default implementation ignores `extra_params` and calls
        find_sources().
        """
        return cls.find_sources(path)

    def __call__(self, path, **extra_params):
        """
        Find the sources at the path and return their descriptions.

        Parameters:
            path: the dataset location. It is normalized with
                os.path.normpath() before the search.
            **extra_params: options added to the 'options' of every source.
                Options already present in a source description win.

        Returns: the list of found source descriptions. The 'options'
            entry of each description is replaced with the merged dict.

        Raises:
            DatasetNotFoundError: if the path is empty or does not exist,
                or if no sources are found.
        """
        if not path or not osp.exists(path):
            raise DatasetNotFoundError(path)

        found_sources = self.find_sources_with_params(
            osp.normpath(path), **extra_params)
        if not found_sources:
            raise DatasetNotFoundError(path)

        sources = []
        for desc in found_sources:
            # Start from the common parameters, so that the options of the
            # source itself override them.
            params = dict(extra_params)
            params.update(desc.get('options', {}))
            desc['options'] = params
            sources.append(desc)

        return sources

    @classmethod
    def _find_sources_recursive(cls, path: str, ext: Optional[str],
            extractor_name: str, filename: str = '*', dirname: str = '',
            file_filter: Optional[Callable[[str], bool]] = None,
            max_depth: int = 3):
        """
        Finds sources in the specified location, using the matching pattern
        to filter file names and directories.
        Supposed to be used, and to be the only call in subclasses.

        If the path itself matches, it is the only result. Otherwise, the
        directory tree is searched level by level, and the search stops at
        the first level where something is found.

        Parameters:
            path: a directory or file path, where sources need to be found.
            ext: file extension to match. To match directories,
                set this parameter to None or ''. Comparison is
                case-independent, a starting dot is not required.
            extractor_name: the name of the associated Extractor type
            filename: a glob pattern for file names
            dirname: a glob pattern for filename prefixes
            file_filter: a callable (abspath: str) -> bool, to filter
                paths found
            max_depth: the maximum depth for recursive search.

        Returns: a list of source configurations
            (i.e. Extractor type names and c-tor parameters)
        """

        if ext:
            if not ext.startswith('.'):
                ext = '.' + ext
            ext = ext.lower()
        else:
            # None is allowed and means the same as '', but it cannot be
            # passed to str.endswith() or concatenated to the file name.
            ext = ''

        # The path matches directly, if it is a file with the required
        # extension, or (when no extension is required) a directory with
        # `dirname` among its path components.
        if (path.lower().endswith(ext) and osp.isfile(path)) or \
            (not ext and dirname and osp.isdir(path) and \
            os.sep + osp.normpath(dirname.lower()) + os.sep in \
                osp.abspath(path.lower()) + os.sep):
            sources = [{'url': path, 'format': extractor_name}]
        else:
            sources = []
            for d in range(max_depth + 1):
                # '*' * d unpacks into d separate '*' path components.
                sources.extend({'url': p, 'format': extractor_name} for p in
                    iglob(osp.join(path, *('*' * d), dirname, filename + ext))
                    if not callable(file_filter) or file_filter(p))
                if sources:
                    break
        return sources