Source code for datumaro.components.launcher

# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import numpy as np

from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import Transform
from datumaro.util import take_by


# pylint: disable=no-self-use
class Launcher(CliPlugin):
    """Base interface for objects that run a model on a batch of inputs.

    Concrete launchers must override :meth:`launch`; overriding
    :meth:`categories` is optional.
    """

    def __init__(self, model_dir=None):
        """Accept an optional model location; the base class ignores it."""

    def launch(self, inputs):
        """Run inference on ``inputs`` and return per-input annotations.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def categories(self):
        """Return categories provided by the model, if any.

        The base implementation provides none, so it returns ``None``.
        """
        return None
# pylint: enable=no-self-use
class ModelTransform(Transform):
    """Dataset transform that annotates items using a model launcher.

    Each item of the wrapped extractor is fed to the launcher, and the item is
    re-wrapped (via ``wrap_item``) with the annotations the launcher returned
    for it. Predicted label ids are validated against the label categories
    reported by :meth:`categories`.
    """

    def __init__(self, extractor, launcher, batch_size=1):
        """
        Args:
            extractor: source dataset whose items are run through the model.
            launcher: object providing ``launch(inputs)`` and ``categories()``
                (see ``Launcher``).
            batch_size: how many items are grouped into one ``launch`` call
                while iterating.
        """
        super().__init__(extractor)
        self._launcher = launcher
        self._batch_size = batch_size

    @staticmethod
    def _prepare_input(item):
        """Return the item's pixel data as an array with at least 3 dims."""
        # np.atleast_3d appends a trailing axis to 2-D (e.g. grayscale) data,
        # so every input has a uniform rank before batching.
        return np.atleast_3d(item.media.data)

    def __iter__(self):
        """Yield items re-wrapped with launcher predictions, batch by batch.

        Raises:
            Exception: if a predicted label id is outside the known labels.
        """
        for batch in take_by(self._extractor, self._batch_size):
            inputs = np.array([self._prepare_input(item) for item in batch])
            inference = self._launcher.launch(inputs)

            for item, annotations in zip(batch, inference):
                self._check_annotations(annotations)
                yield self.wrap_item(item, annotations=annotations)

    def get_subset(self, name):
        """Return a ModelTransform over subset ``name`` with the same launcher."""
        subset = self._extractor.get_subset(name)
        return __class__(subset, self._launcher, self._batch_size)

    def categories(self):
        """Return the launcher's categories, or the extractor's if it has none."""
        launcher_override = self._launcher.categories()
        if launcher_override is not None:
            return launcher_override
        return self._extractor.categories()

    def transform_item(self, item):
        """Run the launcher on a single item and return the re-wrapped item.

        Raises:
            Exception: if a predicted label id is outside the known labels.
        """
        # Use the same preprocessing as __iter__: the launcher receives a
        # batch of pixel arrays, not media objects.
        inputs = np.expand_dims(self._prepare_input(item), axis=0)
        annotations = self._launcher.launch(inputs)[0]
        self._check_annotations(annotations)
        return self.wrap_item(item, annotations=annotations)

    def _check_annotations(self, annotations):
        """Ensure every labelled annotation refers to a defined label id.

        Annotations without a ``label`` attribute, or with ``label=None``,
        are skipped.

        Raises:
            Exception: if a label id is not in ``range(len(labels))``.
        """
        label_categories = self.categories().get(
            AnnotationType.label, LabelCategories()
        )
        labels_count = len(label_categories.items)

        for ann in annotations:
            # Not every annotation type carries a label; treat those like
            # unlabelled annotations instead of raising AttributeError.
            label = getattr(ann, "label", None)
            if label is None:
                continue
            if label not in range(labels_count):
                raise Exception(
                    "Annotation has unexpected label id %s, "
                    "while there is only %s defined labels." % (label, labels_count)
                )