# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import Dict, Optional, Type, Union
import inspect
import os
import os.path as osp
import shutil

from datumaro.components.converter import Converter
from datumaro.components.dataset import (
    Dataset, DatasetItemStorageDatasetView, IDataset,
)
from datumaro.components.dataset_filter import (
    XPathAnnotationsFilter, XPathDatasetFilter,
)
from datumaro.components.environment import Environment
from datumaro.components.extractor import Transform
from datumaro.components.launcher import Launcher, ModelTransform
from datumaro.components.operations import ExactMerge
from datumaro.components.validator import TaskType, Validator
from datumaro.util import parse_str_enum_value
from datumaro.util.scope import on_error_do, scoped
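
# NOTE: `filter` and `run_model` below delegate to a `transform` helper that
# is not shown in this excerpt. A minimal sketch consistent with those calls
# is included here, assuming it resolves a Transform subclass (or a plugin
# name via the environment) and wraps the result in a Dataset; adjust to the
# actual implementation if it differs.
def transform(dataset: IDataset, method: Union[str, Type[Transform]], *,
        env: Optional[Environment] = None, **kwargs) -> IDataset:
    """
    Applies some function to dataset items.

    Returns: a wrapper around the input dataset, which is computed lazily
    during iteration
    """
    if isinstance(method, str):
        if env is None:
            env = Environment()
        method = env.transforms[method]

    if not (inspect.isclass(method) and issubclass(method, Transform)):
        raise TypeError("Unexpected 'method' argument type: %s" % type(method))

    produced = method(dataset, **kwargs)
    return Dataset(source=produced, env=env)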

def filter(dataset: IDataset, expr: str, *, #pylint: disable=redefined-builtin
        filter_annotations: bool = False,
        remove_empty: bool = False) -> IDataset:
    """
    Filters a dataset, keeping only the items or annotations that match
    a custom filter expression.

    Args:
        dataset - The dataset to be filtered
        expr - XPath-formatted filter expression
            (e.g. `/item[subset = 'train']`, `/item/annotation[label = 'cat']`)
        filter_annotations - Indicates if the filter should be
            applied to items or annotations
        remove_empty - When filtering annotations, excludes items
            that have no annotations left from the resulting dataset

    Returns: a wrapper around the input dataset, which is computed lazily
    during iteration
    """
    if filter_annotations:
        return transform(dataset, XPathAnnotationsFilter,
            xpath=expr, remove_empty=remove_empty)
    else:
        if not expr:
            return dataset
        return transform(dataset, XPathDatasetFilter, xpath=expr)
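
# Usage sketch for `filter` (hypothetical `dataset` variable, assumed to be
# an IDataset instance; the expressions come from the docstring above):
#
#   train_items = filter(dataset, "/item[subset = 'train']")
#   cat_annotations = filter(dataset, "/item/annotation[label = 'cat']",
#       filter_annotations=True, remove_empty=True)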

def merge(*datasets: IDataset) -> IDataset:
    """
    Merges several datasets using the "simple" (exact matching) algorithm:

    - items are matched by (id, subset) pairs
    - matching items share the fields available

      - nothing + nothing = nothing
      - nothing + something = something
      - something A + something B = conflict

    - annotations are matched by value and shared
    - in case of conflicts, an error is raised

    Returns: a wrapper around the input datasets
    """
    categories = ExactMerge.merge_categories(d.categories() for d in datasets)
    return DatasetItemStorageDatasetView(ExactMerge.merge(*datasets),
        categories=categories)
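
# Usage sketch for `merge` (hypothetical `dataset_a` / `dataset_b` variables):
#
#   merged = merge(dataset_a, dataset_b)
#   # The result is a lazy view; iterating or exporting it performs the
#   # actual matching, and conflicting items raise an error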

def run_model(dataset: IDataset,
        model: Union[Launcher, Type[ModelTransform]], *,
        batch_size: int = 1, **kwargs) -> IDataset:
    """
    Applies a model to dataset items' media and produces a dataset with
    media and annotations.

    Args:
        dataset - The dataset to be transformed
        model - The model to be applied to the dataset
        batch_size - The number of dataset items processed
            simultaneously by the model
        **kwargs - Parameters for the model

    Returns: a wrapper around the input dataset, which is computed lazily
    during iteration
    """
    if isinstance(model, Launcher):
        return transform(dataset, ModelTransform, launcher=model,
            batch_size=batch_size, **kwargs)
    elif inspect.isclass(model) and issubclass(model, ModelTransform):
        return transform(dataset, model,
            batch_size=batch_size, **kwargs)
    else:
        raise TypeError('Unexpected model argument type: %s' % type(model))
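
# Usage sketch for `run_model` (hypothetical `dataset` and `my_launcher`
# variables, where `my_launcher` is a Launcher implementation for some
# inference backend):
#
#   inferred = run_model(dataset, my_launcher, batch_size=4)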

@scoped
def export(dataset: IDataset, path: str,
        format: Union[str, Type[Converter]], *,
        env: Optional[Environment] = None, **kwargs) -> None:
    """
    Saves the input dataset in the specified format.

    Args:
        dataset - The dataset to be saved
        path - The output directory
        format - The desired output format for the dataset.
            If a string is passed, it is treated as a plugin name,
            which is searched for in the environment set by the 'env' argument
        env - A plugin collection. If not set, the built-in plugins are used
        **kwargs - Parameters for the export format
    """
    if isinstance(format, str):
        if env is None:
            env = Environment()
        converter = env.converters[format]
    else:
        converter = format

    if not (inspect.isclass(converter) and issubclass(converter, Converter)):
        raise TypeError("Unexpected 'format' argument type: %s" % \
            type(converter))

    path = osp.abspath(path)
    if not osp.exists(path):
        # The directory is created below; schedule its removal on error,
        # but only if it did not exist before this call
        on_error_do(shutil.rmtree, path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)

    converter.convert(dataset, save_dir=path, **kwargs)
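
# Usage sketch for `export` (hypothetical `dataset` variable; 'coco' is an
# example plugin name - the available names depend on the environment):
#
#   export(dataset, '/tmp/dataset_export', 'coco')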

def validate(dataset: IDataset, task: Union[str, TaskType], *,
        env: Optional[Environment] = None, **kwargs) -> Dict:
    """
    Checks dataset annotations for correctness with respect to a task type.

    Args:
        dataset - The dataset to check
        task - Target task type - classification, detection etc.
        env - A plugin collection. If not set, the built-in plugins are used
        **kwargs - Parameters for the validator

    Returns: a dictionary with validation results
    """
    task = parse_str_enum_value(task, TaskType).name
    if env is None:
        env = Environment()

    validator: Validator = env.validators[task](**kwargs)
    return validator.validate(dataset)
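
# Usage sketch for `validate` (hypothetical `dataset` variable; 'detection'
# is one of the TaskType values, alongside e.g. 'classification'):
#
#   report = validate(dataset, 'detection')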