# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
import os.path as osp
from enum import Enum, auto
from typing import Iterable, Optional, Sequence, Tuple, Union
from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.converter import Converter
from datumaro.components.errors import DatasetImportError, MediaTypeError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file
class ImagenetTxtPath:
    """File and directory names used by the ImageNet-txt layout."""

    # Sub-directory of the save dir into which exported images are written.
    IMAGE_DIR = "images"
    # File holding label names, one name per line.
    LABELS_FILE = "synsets.txt"
class _LabelsSource(Enum):
file = auto()
generate = auto()
def _parse_annotation_line(line: str) -> Tuple[str, str, Sequence[int]]:
item = line.split('"')
if 1 < len(item):
if len(item) == 3:
item_id = item[1]
item = item[2].split()
image = item_id + item[0]
label_ids = [int(id) for id in item[1:]]
else:
raise Exception("Line %s: unexpected number " "of quotes in filename" % line)
else:
item = line.split()
item_id = osp.splitext(item[0])[0]
image = item[0]
label_ids = [int(id) for id in item[1:]]
return item_id, image, label_ids
class ImagenetTxtImporter(Importer, CliPlugin):
    """Importer for datasets stored as ImageNet-txt ``*.txt`` annotation files."""

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """Require a ``*.txt`` annotation file that carries at least one label.

        The labels file (``synsets.txt``) is not considered. Every line read is
        parsed with ``_parse_annotation_line``; a parse error propagates out of
        the probe context. If no parsed line has any label ids, an exception is
        raised as well, because a file without labels is probably not actually
        an ImageNet file.
        """
        annotation_path = context.require_file(
            "*.txt", exclude_fnames=ImagenetTxtPath.LABELS_FILE
        )
        with context.probe_text_file(
            annotation_path,
            "must be an ImageNet-like annotation file",
        ) as annotation_file:
            # any() stops at the first line that has labels, mirroring a loop
            # that breaks on the first labelled line.
            found_labels = any(
                _parse_annotation_line(text)[2] for text in annotation_file
            )
            if not found_labels:
                raise Exception

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base parser with ``--labels`` and ``--labels-file``."""
        parser = super().build_cmdline_parser(**kwargs)

        labels_help = (
            "Where to get label descriptions from (use "
            "'file' to load from the file specified by --labels-file; "
            "'generate' to create generic ones)"
        )
        parser.add_argument(
            "--labels",
            choices=_LabelsSource.__members__,
            default=_LabelsSource.file.name,
            help=labels_help,
        )

        parser.add_argument(
            "--labels-file",
            default=ImagenetTxtPath.LABELS_FILE,
            help="Path to the file with label descriptions (synsets.txt)",
        )
        return parser

    @classmethod
    def find_sources_with_params(cls, path, **extra_params):
        """Search *path* recursively for ``.txt`` sources of type ``imagenet_txt``.

        When labels come from a file (``labels`` absent or equal to ``"file"``),
        the filter returns False for any file whose base name equals that of the
        labels file (``labels_file`` or ``synsets.txt``).
        """
        labels_source = extra_params.get("labels", _LabelsSource.file.name)
        if labels_source != _LabelsSource.file.name:
            file_filter = None
        else:
            excluded_name = osp.basename(
                extra_params.get("labels_file") or ImagenetTxtPath.LABELS_FILE
            )

            def file_filter(candidate):
                return osp.basename(candidate) != excluded_name

        return cls._find_sources_recursive(
            path, ".txt", "imagenet_txt", file_filter=file_filter
        )
class ImagenetTxtConverter(Converter):
    """Exporter for the ImageNet-txt format."""

    DEFAULT_IMAGE_EXT = ".jpg"

    def apply(self):
        """Write the dataset to the save dir in ImageNet-txt layout.

        For every subset a ``<subset>.txt`` file is written with one line per
        item: ``<image name> <label id> ...``. The image name is the item id,
        wrapped in double quotes when it contains any whitespace, followed by
        the extension returned by ``_find_image_ext``. Duplicate labels on an
        item are written once. Images are saved under ``images/`` when media
        saving is enabled. Afterwards either the dataset meta file or
        ``synsets.txt`` (one label name per line) is written.

        Raises:
            MediaTypeError: if the dataset media type is set and is not Image.
        """
        extractor = self._extractor

        media_type = extractor.media_type()
        if media_type and not issubclass(media_type, Image):
            raise MediaTypeError("Media type is not an image")

        subset_dir = self._save_dir
        os.makedirs(subset_dir, exist_ok=True)

        for subset_name, subset in extractor.subsets().items():
            annotation_file = osp.join(subset_dir, "%s.txt" % subset_name)

            labels = {}
            for item in subset:
                item_id = item.id
                # Quote ids containing any whitespace (including leading or
                # trailing whitespace, which a plain split() check misses) so
                # the whole id survives whitespace-based parsing.
                # NOTE(review): an id containing '"' is written as-is and will
                # not parse back with _parse_annotation_line's quote
                # splitting; consider rejecting such ids.
                if any(ch.isspace() for ch in item_id):
                    item_id = '"' + item_id + '"'
                item_id += self._find_image_ext(item)
                labels[item_id] = set(
                    ann.label for ann in item.annotations if ann.type == AnnotationType.label
                )

                if self._save_media and item.media:
                    self._save_image(item, subdir=ImagenetTxtPath.IMAGE_DIR)

            # Stream lines to the file instead of growing one string with +=.
            with open(annotation_file, "w", encoding="utf-8") as f:
                f.writelines(
                    "%s %s\n" % (image_name, " ".join(str(label_id) for label_id in item_labels))
                    for image_name, item_labels in labels.items()
                )

        if self._save_dataset_meta:
            self._save_meta_file(subset_dir)
        else:
            labels_file = osp.join(subset_dir, ImagenetTxtPath.LABELS_FILE)
            with open(labels_file, "w", encoding="utf-8") as f:
                f.writelines(
                    label.name + "\n"
                    for label in extractor.categories().get(
                        AnnotationType.label, LabelCategories()
                    )
                )