Source code for datumaro.components.launcher
# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
import numpy as np
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import Transform
from datumaro.util import take_by
# pylint: disable=no-self-use
# pylint: enable=no-self-use
class ModelTransform(Transform):
    """Transform that replaces item annotations with a launcher's output.

    Items of the wrapped extractor are passed to ``launcher.launch()`` and
    re-emitted through ``wrap_item`` with the annotations the launcher
    returned for them.
    """

    def __init__(self, extractor, launcher, batch_size=1):
        """Store the launcher and the batch size used during iteration.

        Args:
            extractor: Source of items; forwarded to ``Transform.__init__``.
            launcher: Object providing ``launch(inputs)`` and ``categories()``.
            batch_size: Group size passed to ``take_by`` in ``__iter__``.
        """
        super().__init__(extractor)
        self._launcher = launcher
        self._batch_size = batch_size

    def __iter__(self):
        """Yield wrapped items, running the launcher once per batch.

        Raises:
            Exception: If a produced annotation has a label id outside the
                defined label range (see ``_check_annotations``).
        """
        for batch in take_by(self._extractor, self._batch_size):
            # Each item's data is made at least 3-D before stacking into
            # a single array for the launcher.
            inputs = np.array([np.atleast_3d(item.media.data) for item in batch])
            inference = self._launcher.launch(inputs)

            for item, annotations in zip(batch, inference):
                self._check_annotations(annotations)
                yield self.wrap_item(item, annotations=annotations)

    def get_subset(self, name):
        """Return a ``ModelTransform`` over the named subset of the extractor.

        The new transform shares this one's launcher and batch size.
        """
        subset = self._extractor.get_subset(name)
        return __class__(subset, self._launcher, self._batch_size)

    def categories(self):
        """Return the launcher's categories, or the extractor's as fallback.

        The extractor's categories are used only when the launcher's
        ``categories()`` returns ``None``.
        """
        launcher_override = self._launcher.categories()
        if launcher_override is not None:
            return launcher_override
        return self._extractor.categories()

    def transform_item(self, item):
        """Run the launcher on a single item and return the wrapped result.

        Raises:
            Exception: If a produced annotation has a label id outside the
                defined label range (see ``_check_annotations``).
        """
        # Build the same per-item input as __iter__ (item.media.data made
        # at least 3-D), shaped as a batch of one.
        inputs = np.expand_dims(np.atleast_3d(item.media.data), axis=0)
        annotations = self._launcher.launch(inputs)[0]
        # Validate here as well so both inference paths behave the same.
        self._check_annotations(annotations)
        return self.wrap_item(item, annotations=annotations)

    def _check_annotations(self, annotations):
        """Verify that every labelled annotation uses a defined label id.

        Annotations without a ``label`` attribute, or whose ``label`` is
        ``None``, are skipped.

        Raises:
            Exception: If a label id is not in ``range(labels_count)``.
        """
        labels_count = len(
            self.categories().get(AnnotationType.label, LabelCategories()).items
        )

        for ann in annotations:
            # Default to None so annotations that carry no ``label``
            # attribute are skipped instead of raising AttributeError.
            label = getattr(ann, "label", None)
            if label is None:
                continue

            if label not in range(labels_count):
                raise Exception(
                    "Annotation has unexpected label id %s, "
                    "while there is only %s defined labels." % (label, labels_count)
                )