# Copyright (C) 2019-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import os
import os.path as osp
import warnings
from glob import iglob
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import attr
import numpy as np
from attr import attrs, define, field
from datumaro.components.annotation import Annotation, AnnotationType, Categories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.errors import (
AnnotationImportError,
DatasetNotFoundError,
DatumaroError,
ItemImportError,
)
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.media import Image, MediaElement, PointCloud
from datumaro.components.progress_reporting import NullProgressReporter, ProgressReporter
from datumaro.util import is_method_redefined
from datumaro.util.attrs_util import default_if_none, not_empty
DEFAULT_SUBSET_NAME = "default"
T = TypeVar("T", bound=MediaElement)
@attrs(order=False, init=False, slots=True)
class DatasetItem:
id: str = field(converter=lambda x: str(x).replace("\\", "/"), validator=not_empty)
subset: str = field(converter=lambda v: v or DEFAULT_SUBSET_NAME, default=None)
media: Optional[MediaElement] = field(
default=None, validator=attr.validators.optional(attr.validators.instance_of(MediaElement))
)
annotations: List[Annotation] = field(factory=list, validator=default_if_none(list))
attributes: Dict[str, Any] = field(factory=dict, validator=default_if_none(dict))
    def wrap(item, **kwargs):
        # Return a copy of the item with the given fields replaced
        return attr.evolve(item, **kwargs)
    def __init__(
self,
id: str,
*,
subset: Optional[str] = None,
media: Union[str, MediaElement, None] = None,
annotations: Optional[List[Annotation]] = None,
        attributes: Optional[Dict[str, Any]] = None,
image=None,
point_cloud=None,
related_images=None,
):
if image is not None:
warnings.warn(
"'image' is deprecated and will be " "removed in future. Use 'media' instead.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(image, str):
image = Image(path=image)
elif isinstance(image, np.ndarray) or callable(image):
image = Image(data=image)
assert isinstance(image, Image)
media = image
elif point_cloud is not None:
warnings.warn(
"'point_cloud' is deprecated and will be "
"removed in future. Use 'media' instead.",
DeprecationWarning,
stacklevel=2,
)
if related_images is not None:
warnings.warn(
"'related_images' is deprecated and will be "
"removed in future. Use 'media' instead.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(point_cloud, str):
point_cloud = PointCloud(path=point_cloud, extra_images=related_images)
assert isinstance(point_cloud, PointCloud)
media = point_cloud
self.__attrs_init__(
id=id, subset=subset, media=media, annotations=annotations, attributes=attributes
)
# Deprecated. Provided for backward compatibility.
@property
def image(self) -> Optional[Image]:
warnings.warn(
"'DatasetItem.image' is deprecated and will be "
"removed in future. Use '.media' and '.media_as()' instead.",
DeprecationWarning,
stacklevel=2,
)
if not isinstance(self.media, Image):
return None
return self.media_as(Image)
# Deprecated. Provided for backward compatibility.
@property
def point_cloud(self) -> Optional[str]:
warnings.warn(
"'DatasetItem.point_cloud' is deprecated and will be "
"removed in future. Use '.media' and '.media_as()' instead.",
DeprecationWarning,
stacklevel=2,
)
if not isinstance(self.media, PointCloud):
return None
return self.media_as(PointCloud).path
# Deprecated. Provided for backward compatibility.
@property
def related_images(self) -> List[Image]:
warnings.warn(
"'DatasetItem.related_images' is deprecated and will be "
"removed in future. Use '.media' and '.media_as()' instead.",
DeprecationWarning,
stacklevel=2,
)
if not isinstance(self.media, PointCloud):
return []
return self.media_as(PointCloud).extra_images
# Deprecated. Provided for backward compatibility.
@property
def has_image(self):
warnings.warn(
"'DatasetItem.has_image' is deprecated and will be "
"removed in future. Use '.media' and '.media_as()' instead.",
DeprecationWarning,
stacklevel=2,
)
return isinstance(self.media, Image)
# Deprecated. Provided for backward compatibility.
@property
def has_point_cloud(self):
warnings.warn(
"'DatasetItem.has_point_cloud' is deprecated and will be "
"removed in future. Use '.media' and '.media_as()' instead.",
DeprecationWarning,
stacklevel=2,
)
return isinstance(self.media, PointCloud)
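
# Usage sketch (illustrative, not part of the original module): building a
# DatasetItem with an in-memory image. `Label` comes from
# datumaro.components.annotation; the literal ids and shapes are made up.
#
#   from datumaro.components.annotation import Label
#
#   item = DatasetItem(
#       id="frame_000001",
#       subset="train",
#       media=Image(data=np.zeros((8, 8, 3), dtype=np.uint8)),
#       annotations=[Label(label=0)],
#       attributes={"source": "example"},
#   )
#   # wrap() produces a copy with some fields replaced
#   val_item = item.wrap(subset="val")
#   assert val_item.id == item.id and val_item.subset == "val"
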
CategoriesInfo = Dict[AnnotationType, Categories]
class _ExtractorBase(IExtractor):
def __init__(self, *, length: Optional[int] = None, subsets: Optional[Sequence[str]] = None):
self._length = length
self._subsets = subsets
    def _init_cache(self):
        subsets = set()
        length = -1
        for length, item in enumerate(self):
            subsets.add(item.subset)
        length += 1  # the loop leaves the last index; +1 turns it into a count

        if self._length is None:
            self._length = length
        if self._subsets is None:
            self._subsets = subsets
def __len__(self):
if self._length is None:
self._init_cache()
return self._length
def subsets(self) -> Dict[str, IExtractor]:
if self._subsets is None:
self._init_cache()
return {name or DEFAULT_SUBSET_NAME: self.get_subset(name) for name in self._subsets}
def get_subset(self, name):
if self._subsets is None:
self._init_cache()
if name in self._subsets:
if len(self._subsets) == 1:
return self
subset = self.select(lambda item: item.subset == name)
subset._subsets = [name]
return subset
else:
raise KeyError(
"Unknown subset '%s', available subsets: %s" % (name, set(self._subsets))
)
def transform(self, method, *args, **kwargs):
return method(self, *args, **kwargs)
def select(self, pred):
class _DatasetFilter(_ExtractorBase):
def __iter__(_):
return filter(pred, iter(self))
def categories(_):
return self.categories()
def media_type(_):
return self.media_type()
return _DatasetFilter()
def categories(self):
return {}
def get(self, id, subset=None):
subset = subset or DEFAULT_SUBSET_NAME
for item in self:
if item.id == id and item.subset == subset:
return item
return None
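
# Usage sketch (illustrative, not part of the original module): the smallest
# possible extractor only needs to provide __iter__; length and subset info
# are then derived lazily by _init_cache(). The class below is hypothetical.
#
#   class _TwoItemExtractor(_ExtractorBase):
#       def __iter__(self):
#           yield DatasetItem(id="a", subset="train")
#           yield DatasetItem(id="b", subset="val")
#
#   e = _TwoItemExtractor()
#   assert len(e) == 2
#   assert set(e.subsets()) == {"train", "val"}
#   assert e.get("a", subset="train") is not None
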
T = TypeVar("T")
class _ImportFail(DatumaroError):
pass
class ImportErrorPolicy:
def report_item_error(self, error: Exception, *, item_id: Tuple[str, str]) -> None:
"""
        Allows reporting a problem with a dataset item.
If this function returns, the extractor must skip the item.
"""
if not isinstance(error, _ImportFail):
ie = ItemImportError(item_id)
ie.__cause__ = error
return self._handle_item_error(ie)
else:
raise error
def report_annotation_error(self, error: Exception, *, item_id: Tuple[str, str]) -> None:
"""
        Allows reporting a problem with a dataset item annotation.
If this function returns, the extractor must skip the annotation.
"""
if not isinstance(error, _ImportFail):
ie = AnnotationImportError(item_id)
ie.__cause__ = error
return self._handle_annotation_error(ie)
else:
raise error
def _handle_item_error(self, error: ItemImportError) -> None:
"""This function must either call fail() or return."""
self.fail(error)
def _handle_annotation_error(self, error: AnnotationImportError) -> None:
"""This function must either call fail() or return."""
self.fail(error)
def fail(self, error: Exception) -> NoReturn:
raise _ImportFail from error
class FailingImportErrorPolicy(ImportErrorPolicy):
pass
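
# Sketch of a lenient policy (illustrative, not part of the original module):
# overriding the _handle_* hooks to return, instead of calling fail(), makes
# the extractor skip the offending item or annotation. The class name is
# made up.
#
#   class _CollectingImportErrorPolicy(ImportErrorPolicy):
#       def __init__(self):
#           self.errors = []
#
#       def _handle_item_error(self, error):
#           self.errors.append(error)  # record and return => item is skipped
#
#       def _handle_annotation_error(self, error):
#           self.errors.append(error)  # record and return => annotation is skipped
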
@define(eq=False)
class ImportContext:
progress_reporter: ProgressReporter = field(
default=None, converter=attr.converters.default_if_none(factory=NullProgressReporter)
)
error_policy: ImportErrorPolicy = field(
default=None, converter=attr.converters.default_if_none(factory=FailingImportErrorPolicy)
)
class NullImportContext(ImportContext):
pass
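
# Usage sketch (illustrative): an ImportContext combining the default
# progress reporter with the lenient policy sketched above; both arguments
# are optional and fall back to the null/failing implementations via the
# field converters.
#
#   ctx = ImportContext(error_policy=_CollectingImportErrorPolicy())
#   assert isinstance(ctx.progress_reporter, NullProgressReporter)
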
class Importer(CliPlugin):
    @classmethod
def detect(
cls,
context: FormatDetectionContext,
) -> Optional[FormatDetectionConfidence]:
if not cls.find_sources_with_params(context.root_path):
context.fail("specific requirement information unavailable")
return FormatDetectionConfidence.LOW
    @classmethod
def find_sources(cls, path) -> List[Dict]:
raise NotImplementedError()
    @classmethod
def find_sources_with_params(cls, path, **extra_params) -> List[Dict]:
return cls.find_sources(path)
def __call__(self, path, **extra_params):
if not path or not osp.exists(path):
raise DatasetNotFoundError(path)
found_sources = self.find_sources_with_params(osp.normpath(path), **extra_params)
if not found_sources:
raise DatasetNotFoundError(path)
sources = []
for desc in found_sources:
params = dict(extra_params)
params.update(desc.get("options", {}))
desc["options"] = params
sources.append(desc)
return sources
    @classmethod
def _find_sources_recursive(
cls,
path: str,
ext: Optional[str],
extractor_name: str,
filename: str = "*",
dirname: str = "",
file_filter: Optional[Callable[[str], bool]] = None,
max_depth: int = 3,
):
"""
Finds sources in the specified location, using the matching pattern
to filter file names and directories.
Supposed to be used, and to be the only call in subclasses.
Parameters:
path: a directory or file path, where sources need to be found.
ext: file extension to match. To match directories,
set this parameter to None or ''. Comparison is case-independent,
a starting dot is not required.
extractor_name: the name of the associated Extractor type
filename: a glob pattern for file names
dirname: a glob pattern for filename prefixes
file_filter: a callable (abspath: str) -> bool, to filter paths found
max_depth: the maximum depth for recursive search.
Returns: a list of source configurations
(i.e. Extractor type names and c-tor parameters)
"""
if ext:
if not ext.startswith("."):
ext = "." + ext
ext = ext.lower()
        if (path.lower().endswith(ext) and osp.isfile(path)) or (
            not ext
            and dirname
            and osp.isdir(path)
            # match 'dirname' as a whole path component of 'path'
            and os.sep + osp.normpath(dirname.lower()) + os.sep
            in osp.abspath(path.lower()) + os.sep
        ):
sources = [{"url": path, "format": extractor_name}]
else:
sources = []
for d in range(max_depth + 1):
sources.extend(
{"url": p, "format": extractor_name}
for p in iglob(osp.join(path, *("*" * d), dirname, filename + ext))
                    if not callable(file_filter) or file_filter(p)
)
                if sources:
                    break  # stop at the shallowest depth that yields matches
return sources
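
# Usage sketch (illustrative, not part of the original module): a typical
# importer implements find_sources() as a single call to
# _find_sources_recursive(). The format name and extension below are made up.
#
#   class _TxtListImporter(Importer):
#       @classmethod
#       def find_sources(cls, path):
#           # match '*.txt' files up to the default depth of 3
#           return cls._find_sources_recursive(path, ".txt", "txt_list")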