Source code for datumaro.plugins.imagenet_txt_format

# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp
from enum import Enum, auto
from typing import Iterable, Optional, Sequence, Tuple, Union

from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.converter import Converter
from datumaro.components.errors import DatasetImportError, MediaTypeError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


[docs]class ImagenetTxtPath: LABELS_FILE = "synsets.txt" IMAGE_DIR = "images"
class _LabelsSource(Enum): file = auto() generate = auto() def _parse_annotation_line(line: str) -> Tuple[str, str, Sequence[int]]: item = line.split('"') if 1 < len(item): if len(item) == 3: item_id = item[1] item = item[2].split() image = item_id + item[0] label_ids = [int(id) for id in item[1:]] else: raise Exception("Line %s: unexpected number " "of quotes in filename" % line) else: item = line.split() item_id = osp.splitext(item[0])[0] image = item[0] label_ids = [int(id) for id in item[1:]] return item_id, image, label_ids
[docs]class ImagenetTxtExtractor(SourceExtractor):
[docs] def __init__( self, path: str, *, labels: Union[Iterable[str], str] = _LabelsSource.file.name, labels_file: str = ImagenetTxtPath.LABELS_FILE, image_dir: Optional[str] = None, subset: Optional[str] = None, ): assert osp.isfile(path), path if not subset: subset = osp.splitext(osp.basename(path))[0] super().__init__(subset=subset) root_dir = osp.dirname(path) if not image_dir: image_dir = ImagenetTxtPath.IMAGE_DIR self.image_dir = osp.join(root_dir, image_dir) self._generate_labels = False if isinstance(labels, str): labels_source = _LabelsSource[labels] if labels_source == _LabelsSource.generate: labels = () self._generate_labels = True elif labels_source == _LabelsSource.file: if has_meta_file(root_dir): labels = parse_meta_file(root_dir).keys() else: labels = self._parse_labels(osp.join(root_dir, labels_file)) else: assert False, "Unhandled labels source %s" % labels_source else: assert all(isinstance(e, str) for e in labels) self._categories = self._load_categories(labels) self._items = list(self._load_items(path).values())
@staticmethod def _parse_labels(path): with open(path, encoding="utf-8") as labels_file: return [s.strip() for s in labels_file] def _load_categories(self, labels): return {AnnotationType.label: LabelCategories.from_iterable(labels)} def _load_items(self, path): items = {} with open(path, encoding="utf-8") as f: for line in f: item_id, image, label_ids = _parse_annotation_line(line) anno = [] label_categories = self._categories[AnnotationType.label] for label in label_ids: if label < 0: raise DatasetImportError(f"Image '{item_id}': invalid label id '{label}'") if len(label_categories) <= label: if self._generate_labels: while len(label_categories) <= label: label_categories.add(f"class-{len(label_categories)}") else: raise DatasetImportError( f"Image '{item_id}': unknown label id '{label}'" ) anno.append(Label(label)) items[item_id] = DatasetItem( id=item_id, subset=self._subset, media=Image(path=osp.join(self.image_dir, image)), annotations=anno, ) return items
[docs]class ImagenetTxtImporter(Importer, CliPlugin):
[docs] @classmethod def detect(cls, context: FormatDetectionContext) -> None: annot_path = context.require_file("*.txt", exclude_fnames=ImagenetTxtPath.LABELS_FILE) with context.probe_text_file( annot_path, "must be an ImageNet-like annotation file", ) as f: for line in f: _, _, label_ids = _parse_annotation_line(line) if label_ids: break else: # If there are no labels in the entire file, it's probably # not actually an ImageNet file. raise Exception
[docs] @classmethod def build_cmdline_parser(cls, **kwargs): parser = super().build_cmdline_parser(**kwargs) parser.add_argument( "--labels", choices=_LabelsSource.__members__, default=_LabelsSource.file.name, help="Where to get label descriptions from (use " "'file' to load from the file specified by --labels-file; " "'generate' to create generic ones)", ) parser.add_argument( "--labels-file", default=ImagenetTxtPath.LABELS_FILE, help="Path to the file with label descriptions (synsets.txt)", ) return parser
[docs] @classmethod def find_sources_with_params(cls, path, **extra_params): if "labels" not in extra_params or extra_params["labels"] == _LabelsSource.file.name: labels_file_name = osp.basename( extra_params.get("labels_file") or ImagenetTxtPath.LABELS_FILE ) def file_filter(p): return osp.basename(p) != labels_file_name else: file_filter = None return cls._find_sources_recursive(path, ".txt", "imagenet_txt", file_filter=file_filter)
[docs]class ImagenetTxtConverter(Converter): DEFAULT_IMAGE_EXT = ".jpg"
[docs] def apply(self): if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image): raise MediaTypeError("Media type is not an image") subset_dir = self._save_dir os.makedirs(subset_dir, exist_ok=True) extractor = self._extractor for subset_name, subset in self._extractor.subsets().items(): annotation_file = osp.join(subset_dir, "%s.txt" % subset_name) labels = {} for item in subset: item_id = item.id if 1 < len(item_id.split()): item_id = '"' + item_id + '"' item_id += self._find_image_ext(item) labels[item_id] = set( p.label for p in item.annotations if p.type == AnnotationType.label ) if self._save_media and item.media: self._save_image(item, subdir=ImagenetTxtPath.IMAGE_DIR) annotation = "" for item_id, item_labels in labels.items(): annotation += "%s %s\n" % (item_id, " ".join(str(l) for l in item_labels)) with open(annotation_file, "w", encoding="utf-8") as f: f.write(annotation) if self._save_dataset_meta: self._save_meta_file(subset_dir) else: labels_file = osp.join(subset_dir, ImagenetTxtPath.LABELS_FILE) with open(labels_file, "w", encoding="utf-8") as f: f.writelines( l.name + "\n" for l in extractor.categories().get(AnnotationType.label, LabelCategories()) )