# Copyright (C) 2019-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
from glob import glob
import logging as log
import os.path as osp
from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Importer
from datumaro.components.format_detection import (
FormatDetectionConfidence, FormatDetectionContext,
)
from datumaro.plugins.coco_format.extractor import (
CocoCaptionsExtractor, CocoImageInfoExtractor, CocoInstancesExtractor,
CocoLabelsExtractor, CocoPanopticExtractor, CocoPersonKeypointsExtractor,
CocoStuffExtractor,
)
from .format import CocoTask
class CocoImporter(Importer):
    """Find COCO annotation files and describe them as dataset sources.

    Annotation files are expected to be named ``<task>_<subset>.json``.
    They are grouped by subset and every usable file is turned into one
    source description. Subclasses narrow ``_TASKS`` to a single entry to
    obtain a single-task importer.
    """

    # Every supported COCO task and the extractor class that reads it.
    _TASKS = {
        CocoTask.instances: CocoInstancesExtractor,
        CocoTask.person_keypoints: CocoPersonKeypointsExtractor,
        CocoTask.captions: CocoCaptionsExtractor,
        CocoTask.labels: CocoLabelsExtractor,
        CocoTask.image_info: CocoImageInfoExtractor,
        CocoTask.panoptic: CocoPanopticExtractor,
        CocoTask.stuff: CocoStuffExtractor,
    }

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Return the base command line parser extended with COCO options.

        Args:
            **kwargs: forwarded unchanged to the parent implementation.

        Returns:
            The parser, with the ``--keep-original-category-ids`` flag added.
        """
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('--keep-original-category-ids', action='store_true',
            help="Add dummy label categories so that category indices "
                "correspond to the category IDs in the original annotation "
                "file")
        return parser

    @classmethod
    def detect(
        cls, context: FormatDetectionContext,
    ) -> FormatDetectionConfidence:
        """Register the file layout a dataset must have to match this format.

        At least one ``annotations/<task>_*.json`` file must exist for one
        of the tasks in ``_TASKS``.

        Args:
            context: detection context used to state the requirements.

        Returns:
            Nothing is returned explicitly.
            NOTE(review): presumably the caller then applies its default
            confidence - confirm against the detection framework.
        """
        # The `coco` format is inherently ambiguous with `coco_instances`,
        # `coco_stuff`, etc. To remove the ambiguity (and thus make it possible
        # to use autodetection with the COCO dataset), disable autodetection
        # for the single-task formats.
        if len(cls._TASKS) == 1:
            # NOTE(review): presumably `fail` raises and aborts detection
            # here - confirm.
            context.fail('this format cannot be autodetected')

        with context.require_any():
            for task in cls._TASKS:
                with context.alternative():
                    context.require_file(f'annotations/{task.name}_*.json')

    def __call__(self, path, **extra_params):
        """Build the list of source descriptions for the dataset at ``path``.

        Args:
            path: annotation file or dataset directory to inspect.
            **extra_params: copied into the ``options`` of every source.

        Returns:
            A list of dicts with the keys ``url``, ``format`` and ``options``.

        Raises:
            Exception: if no annotation file was found at ``path``.
        """
        subsets = self.find_sources(path)
        if not subsets:
            raise Exception("Failed to find 'coco' dataset at '%s'" % path)

        # TODO: should be removed when proper label merging is implemented
        conflicting_types = {CocoTask.instances,
            CocoTask.person_keypoints, CocoTask.labels,
            CocoTask.panoptic, CocoTask.stuff}
        ann_types = {
            ann_type
            for ann_files in subsets.values()
            for ann_type in ann_files
        } & conflicting_types

        # Stays None when no label-carrying task was found; in that case no
        # file can enter the conflict check below either.
        selected_ann_type = None
        if ann_types:
            # Alphabetically first task name wins, which keeps the choice
            # independent of set iteration order.
            selected_ann_type = min(ann_types, key=lambda t: t.name)
        if 1 < len(ann_types):
            # Names are sorted so the message is the same on every run.
            log.warning("Not implemented: "
                "Found potentially conflicting source types with labels: %s. "
                "Only one type will be used: %s",
                ", ".join(sorted(t.name for t in ann_types)),
                selected_ann_type.name)

        sources = []
        for ann_files in subsets.values():
            for ann_type, ann_file in ann_files.items():
                if ann_type in conflicting_types \
                        and ann_type is not selected_ann_type:
                    log.warning("Not implemented: "
                        "conflicting source '%s' is skipped.", ann_file)
                    continue
                log.info("Found a dataset at '%s'", ann_file)

                sources.append({
                    'url': ann_file,
                    'format': self._TASKS[ann_type].NAME,
                    'options': dict(extra_params),
                })

        return sources

    @classmethod
    def find_sources(cls, path):
        """Locate annotation files and group them by subset and task.

        Args:
            path: a single annotation file, or a directory that is searched
                recursively for ``*_*.json`` files.

        Returns:
            A dict mapping subset name to a dict that maps a ``CocoTask``
            to the path of its annotation file. Empty if nothing matched.
        """
        # Imported here because the module only imports the `glob` function.
        from glob import escape as escape_glob

        def detect_coco_task(filename):
            # The task is recognised by the `<task name>_` file name prefix.
            for task in CocoTask:
                if filename.startswith(task.name + '_'):
                    return task
            return None

        single_task = len(cls._TASKS) == 1

        if osp.isfile(path):
            if single_task:
                # An explicitly given file is trusted to belong to the only
                # task this importer supports, whatever its name is.
                return {'': {next(iter(cls._TASKS)): path}}

            subset_paths = [path] if path.endswith('.json') else []
        else:
            # Escape the directory part so that characters such as '[' in
            # the path are matched literally instead of acting as a pattern.
            subset_paths = glob(
                osp.join(escape_glob(path), '**', '*_*.json'),
                recursive=True)

        subsets = {}
        for subset_path in subset_paths:
            file_name = osp.basename(subset_path)

            ann_type = detect_coco_task(file_name)
            if ann_type is None and single_task:
                # Unprefixed files fall back to the only supported task.
                ann_type = next(iter(cls._TASKS))

            if ann_type not in cls._TASKS:
                log.warning("File '%s' was skipped, could't match this file "
                    "with any of these tasks: %s",
                    subset_path,
                    ','.join(e.NAME for e in cls._TASKS.values()))
                continue

            # Whatever follows the first `<task name>_` in the file stem is
            # the subset name; files without that marker use the default.
            parts = osp.splitext(file_name)[0] \
                .split(ann_type.name + '_', maxsplit=1)
            subset_name = parts[1] if len(parts) == 2 else DEFAULT_SUBSET_NAME
            subsets.setdefault(subset_name, {})[ann_type] = subset_path

        return subsets
class CocoImageInfoImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``image_info``."""

    _TASK = CocoTask.image_info
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoCaptionsImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``captions``."""

    _TASK = CocoTask.captions
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoInstancesImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``instances``."""

    _TASK = CocoTask.instances
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoPersonKeypointsImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``person_keypoints``."""

    _TASK = CocoTask.person_keypoints
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoLabelsImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``labels``."""

    _TASK = CocoTask.labels
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoPanopticImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``panoptic``."""

    _TASK = CocoTask.panoptic
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoStuffImporter(CocoImporter):
    """Single-task variant of ``CocoImporter`` limited to ``stuff``."""

    _TASK = CocoTask.stuff
    # Reuse the extractor registered for this task in the parent mapping.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}