Source code for datumaro.plugins.coco_format.importer

# Copyright (C) 2019-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

from glob import glob
import logging as log
import os.path as osp

from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Importer
from datumaro.components.format_detection import (
    FormatDetectionConfidence, FormatDetectionContext,
)
from datumaro.plugins.coco_format.extractor import (
    CocoCaptionsExtractor, CocoImageInfoExtractor, CocoInstancesExtractor,
    CocoLabelsExtractor, CocoPanopticExtractor, CocoPersonKeypointsExtractor,
    CocoStuffExtractor,
)

from .format import CocoTask


class CocoImporter(Importer):
    # Importer for COCO-style datasets. It searches for annotation files
    # named '<task>_<subset>.json' and describes each of them as a separate
    # source, to be read by the extractor registered for its task.

    # Maps each supported COCO task to the extractor class that reads it.
    # The single-task importers below narrow this mapping to one entry.
    _TASKS = {
        CocoTask.instances: CocoInstancesExtractor,
        CocoTask.person_keypoints: CocoPersonKeypointsExtractor,
        CocoTask.captions: CocoCaptionsExtractor,
        CocoTask.labels: CocoLabelsExtractor,
        CocoTask.image_info: CocoImageInfoExtractor,
        CocoTask.panoptic: CocoPanopticExtractor,
        CocoTask.stuff: CocoStuffExtractor,
    }

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the parent parser with the COCO-specific options.

        Adds the ``--keep-original-category-ids`` flag and returns the
        parser produced by the parent class.
        """
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('--keep-original-category-ids',
            action='store_true',
            help="Add dummy label categories so that category indices "
                "correspond to the category IDs in the original annotation "
                "file")
        return parser

    @classmethod
    def detect(
        cls, context: FormatDetectionContext,
    ) -> FormatDetectionConfidence:
        """Declare what a directory must contain to look like this format.

        At least one ``annotations/<task>_*.json`` file must exist for one
        of the tasks in ``_TASKS``. No explicit confidence value is returned
        by this method.
        """
        # The `coco` format is inherently ambiguous with `coco_instances`,
        # `coco_stuff`, etc. To remove the ambiguity (and thus make it possible
        # to use autodetection with the COCO dataset), disable autodetection
        # for the single-task formats.
        if len(cls._TASKS) == 1:
            context.raise_unsupported()

        with context.require_any():
            for task in cls._TASKS:
                with context.alternative():
                    context.require_file(f'annotations/{task.name}_*.json')

    def __call__(self, path, **extra_params):
        """Build the list of source descriptions for the dataset at `path`.

        Parameters:
            path: an annotation file or a directory to search.
            **extra_params: copied into the 'options' of every source.

        Returns:
            A list of dicts with the keys 'url' (annotation file path),
            'format' (the NAME of the matching extractor) and 'options'.

        Raises:
            Exception: if no annotation files were found at `path`.
        """
        subsets = self.find_sources(path)
        if not subsets:
            raise Exception("Failed to find 'coco' dataset at '%s'" % path)

        # TODO: should be removed when proper label merging is implemented
        conflicting_types = {CocoTask.instances,
            CocoTask.person_keypoints, CocoTask.labels,
            CocoTask.panoptic, CocoTask.stuff}
        ann_types = {
            ann_type
            for ann_files in subsets.values()
            for ann_type in ann_files
        } & conflicting_types

        # Of the label-carrying tasks that are present, only the one with
        # the alphabetically smallest name is kept.
        selected_ann_type = None
        if ann_types:
            selected_ann_type = min(ann_types, key=lambda t: t.name)
        if 1 < len(ann_types):
            # The names are sorted so the message does not depend on set
            # iteration order.
            log.warning("Not implemented: "
                "Found potentially conflicting source types with labels: %s. "
                "Only one type will be used: %s",
                ", ".join(sorted(t.name for t in ann_types)),
                selected_ann_type.name)

        sources = []
        for ann_files in subsets.values():
            for ann_type, ann_file in ann_files.items():
                if ann_type in conflicting_types and \
                        ann_type is not selected_ann_type:
                    log.warning("Not implemented: "
                        "conflicting source '%s' is skipped.", ann_file)
                    continue
                log.info("Found a dataset at '%s'", ann_file)

                sources.append({
                    'url': ann_file,
                    'format': self._TASKS[ann_type].NAME,
                    'options': dict(extra_params),
                })

        return sources

    @classmethod
    def find_sources(cls, path):
        """Locate COCO annotation files and group them by subset and task.

        Parameters:
            path: an annotation file or a directory. Directories are
                searched recursively for '*_*.json' files.

        Returns:
            A dict mapping subset name -> {task: annotation file path}.
            When `path` is a file and the importer supports a single task,
            the file is returned under the subset name ''.
        """
        # Local import: only this method needs it.
        from glob import escape as escape_glob

        def detect_coco_task(filename):
            # The task is recognised by the '<task name>_' file name prefix.
            for task in CocoTask:
                if filename.startswith(task.name + '_'):
                    return task
            return None

        if osp.isfile(path):
            if len(cls._TASKS) == 1:
                return {'': { next(iter(cls._TASKS)): path }}

            subset_paths = [path] if path.endswith('.json') else []
        else:
            # Escape the user-supplied directory so that characters such as
            # '[' or '*' in its name are matched literally instead of being
            # interpreted as glob syntax.
            subset_paths = glob(
                osp.join(escape_glob(path), '**', '*_*.json'),
                recursive=True)

        subsets = {}
        for subset_path in subset_paths:
            file_name = osp.basename(subset_path)

            ann_type = detect_coco_task(file_name)
            if ann_type is None and len(cls._TASKS) == 1:
                # A single-task importer accepts files without a
                # recognisable task prefix.
                ann_type = next(iter(cls._TASKS))

            if ann_type not in cls._TASKS:
                log.warning("File '%s' was skipped, couldn't match this file "
                    "with any of these tasks: %s",
                    subset_path,
                    ','.join(e.NAME for e in cls._TASKS.values()))
                continue

            # The subset name is whatever follows the first '<task name>_'.
            parts = osp.splitext(file_name)[0] \
                .split(ann_type.name + '_', maxsplit=1)
            subset_name = parts[1] if len(parts) == 2 else DEFAULT_SUBSET_NAME
            subsets.setdefault(subset_name, {})[ann_type] = subset_path

        return subsets
class CocoImageInfoImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `image_info`
    # annotation files.
    _TASK = CocoTask.image_info

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoCaptionsImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `captions`
    # annotation files.
    _TASK = CocoTask.captions

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoInstancesImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `instances`
    # annotation files.
    _TASK = CocoTask.instances

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoPersonKeypointsImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `person_keypoints`
    # annotation files.
    _TASK = CocoTask.person_keypoints

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoLabelsImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `labels`
    # annotation files.
    _TASK = CocoTask.labels

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoPanopticImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `panoptic`
    # annotation files.
    _TASK = CocoTask.panoptic

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
class CocoStuffImporter(CocoImporter):
    # Single-task variant of CocoImporter limited to `stuff`
    # annotation files.
    _TASK = CocoTask.stuff

    # Reuse the extractor that CocoImporter registers for this task.
    _TASKS = {_TASK: CocoImporter._TASKS[_TASK]}