Source code for datumaro.plugins.cifar_format

# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from collections import OrderedDict
import os
import os.path as osp
import pickle  # nosec - disable B403:import_pickle check - fixed

import numpy as np
import numpy.core.multiarray

from datumaro.components.annotation import (
    AnnotationType, Label, LabelCategories,
)
from datumaro.components.converter import Converter
from datumaro.components.dataset import ItemStatus
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.util import cast
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class RestrictedUnpickler(pickle.Unpickler):
    # Only allows a small whitelist of numpy globals to be unpickled,
    # which prevents a crafted CIFAR pickle from executing arbitrary code.
    def find_class(self, module, name):
        if module == "numpy.core.multiarray" and \
                name in PickleLoader.safe_numpy:
            return getattr(numpy.core.multiarray, name)
        elif module == 'numpy' and name in PickleLoader.safe_numpy:
            return getattr(numpy, name)
        raise pickle.UnpicklingError(
            "Global '%s.%s' is forbidden" % (module, name))

class PickleLoader():
    safe_numpy = {
        'dtype',
        'ndarray',
        '_reconstruct',
    }

    def restricted_load(s):
        # Always called as PickleLoader.restricted_load(file),
        # hence no 'self' parameter
        return RestrictedUnpickler(s, encoding='latin1').load()

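# A minimal usage sketch of the restricted loader (illustrative only;
# 'batch_path' is a hypothetical path to a CIFAR batch file):
#
#     with open(batch_path, 'rb') as f:
#         batch = PickleLoader.restricted_load(f)
#
# Unpickling any global outside PickleLoader.safe_numpy raises
# pickle.UnpicklingError instead of importing and executing arbitrary code.
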
class CifarPath:
    META_10_FILE = 'batches.meta'
    META_100_FILE = 'meta'
    TRAIN_FILE_PREFIX = 'data_batch_'
    USELESS_FILE = 'file.txt~'
    IMAGE_SIZE = 32

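# For reference, a CIFAR-10 "Python version" dataset is expected to look
# roughly like this (batch names below are illustrative; subset files
# carry no extension):
#
#     <dataset_dir>/
#         batches.meta    # CifarPath.META_10_FILE with 'label_names'
#         data_batch_1    # pickled dict: 'data', 'labels', 'filenames', ...
#         test            # same structure as the training batches
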
Cifar10Label = ['airplane', 'automobile', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Supports the "Python version" of CIFAR-10/100
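# Each row of a batch's 'data' array stores one image as 3 * 32 * 32 = 3072
# bytes in channel-major order (all red values, then green, then blue),
# which is why loading reshapes to (3, H, W) and then transposes to HWC.
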
class CifarExtractor(SourceExtractor):
    def __init__(self, path, subset=None):
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]
        super().__init__(subset=subset)

        self._categories = self._load_categories(osp.dirname(path))
        self._items = list(self._load_items(path).values())

    def _load_categories(self, path):
        if has_meta_file(path):
            return { AnnotationType.label: LabelCategories.from_iterable(
                parse_meta_file(path).keys()) }

        label_cat = LabelCategories()
        self._coarse_labels = []  # filled only from CIFAR-100 metadata

        meta_file = osp.join(path, CifarPath.META_10_FILE)
        if not osp.isfile(meta_file):
            meta_file = osp.join(path, CifarPath.META_100_FILE)
        if osp.isfile(meta_file):
            # CIFAR-10:
            # num_cases_per_batch: 1000
            # label_names: ['airplane', 'automobile', 'bird', 'cat', 'deer',
            #     'dog', 'frog', 'horse', 'ship', 'truck']
            # num_vis: 3072
            # CIFAR-100:
            # fine_label_names: ['apple', 'aquarium_fish', 'baby', ...]
            # coarse_label_names: ['aquatic_mammals', 'fish', 'flowers', ...]
            with open(meta_file, 'rb') as labels_file:
                data = PickleLoader.restricted_load(labels_file)

            labels = data.get('label_names')
            if labels is not None:
                for label in labels:
                    label_cat.add(label)
            else:
                labels = data.get('fine_label_names')
                self._coarse_labels = data.get('coarse_label_names', [])
                if labels is not None:
                    for label in labels:
                        label_cat.add(label)
        else:
            for label in Cifar10Label:
                label_cat.add(label)

        return { AnnotationType.label: label_cat }

    def _load_items(self, path):
        items = {}
        label_cat = self._categories[AnnotationType.label]

        # Each batch file is a pickled dict:
        # 'batch_label': 'training batch 1 of 5'
        # 'data': ndarray
        # 'filenames': list
        # CIFAR-10: 'labels': list
        # CIFAR-100: 'fine_labels': list
        #            'coarse_labels': list
        with open(path, 'rb') as anno_file:
            annotation_dict = PickleLoader.restricted_load(anno_file)

        labels = annotation_dict.get('labels', [])
        coarse_labels = annotation_dict.get('coarse_labels', [])
        if len(labels) == 0:
            labels = annotation_dict.get('fine_labels', [])
        filenames = annotation_dict.get('filenames', [])
        images_data = annotation_dict.get('data', [])
        size = annotation_dict.get('image_sizes')

        if len(labels) != len(filenames):
            raise Exception("The sizes of the arrays 'filenames' "
                "and 'labels' don't match.")

        if 0 < len(images_data) and len(images_data) != len(filenames):
            raise Exception("The sizes of the arrays 'data', "
                "'filenames' and 'labels' don't match.")

        for i, (filename, label) in enumerate(zip(filenames, labels)):
            item_id = osp.splitext(filename)[0]
            annotations = []
            if label is not None:
                annotations.append(Label(label))
                if 0 < len(coarse_labels) and coarse_labels[i] is not None \
                        and label_cat[label].parent == '':
                    label_cat[label].parent = \
                        self._coarse_labels[coarse_labels[i]]

            image = None
            if 0 < len(images_data):
                image = images_data[i]
                if size is not None and image is not None:
                    image = image.astype(np.uint8) \
                        .reshape(3, size[i][0], size[i][1])
                    image = np.transpose(image, (1, 2, 0))
                elif image is not None:
                    image = image.astype(np.uint8) \
                        .reshape(3, CifarPath.IMAGE_SIZE, CifarPath.IMAGE_SIZE)
                    image = np.transpose(image, (1, 2, 0))

            items[item_id] = DatasetItem(id=item_id, subset=self._subset,
                image=image, annotations=annotations)

        return items

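# Illustrative use of the extractor on a single batch file (the path below
# is hypothetical):
#
#     extractor = CifarExtractor('cifar10/data_batch_1')
#     item = next(iter(extractor))
#     # item.image wraps the 32x32x3 uint8 pixel data; item.annotations
#     # holds a Label pointing into the loaded categories
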
class CifarImporter(Importer):
    @classmethod
    def find_sources(cls, path):
        return cls._find_sources_recursive(path, '', 'cifar',
            file_filter=lambda p:
                # subset files have no extension in the format
                not osp.splitext(osp.basename(p))[1]
                and osp.basename(p) not in {
                    CifarPath.META_10_FILE,
                    CifarPath.META_100_FILE,
                    CifarPath.USELESS_FILE,
                })

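# The importer is normally reached through the generic Datumaro entry point,
# e.g. (the path below is hypothetical):
#
#     from datumaro.components.dataset import Dataset
#     dataset = Dataset.import_from('path/to/cifar10', 'cifar')
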
class CifarConverter(Converter):
    DEFAULT_IMAGE_EXT = '.png'

    def apply(self):
        os.makedirs(self._save_dir, exist_ok=True)

        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        label_categories = self._extractor.categories()[AnnotationType.label]
        label_names = []
        coarse_label_names = []
        for label in label_categories:
            label_names.append(label.name)
            if label.parent and (label.parent not in coarse_label_names):
                coarse_label_names.append(label.parent)
        coarse_label_names.sort()

        if coarse_label_names:
            labels_dict = {
                'fine_label_names': label_names,
                'coarse_label_names': coarse_label_names,
            }
            coarse_label_names = OrderedDict(
                (name, i) for i, name in enumerate(coarse_label_names))
            meta_file = osp.join(self._save_dir, CifarPath.META_100_FILE)
        else:
            labels_dict = { 'label_names': label_names }
            meta_file = osp.join(self._save_dir, CifarPath.META_10_FILE)
        with open(meta_file, 'wb') as f:
            pickle.dump(labels_dict, f)

        for subset_name, subset in self._extractor.subsets().items():
            filenames = []
            labels = []
            coarse_labels = []
            data = []
            image_sizes = {}
            for item in subset:
                filenames.append(self._make_image_filename(item))

                anns = [a for a in item.annotations
                    if a.type == AnnotationType.label]
                if anns:
                    labels.append(anns[0].label)
                    if coarse_label_names:
                        superclass = label_categories[anns[0].label].parent
                        coarse_labels.append(coarse_label_names[superclass])
                else:
                    labels.append(None)
                    coarse_labels.append(None)

                if self._save_images and item.has_image:
                    image = item.image
                    if not image.has_data:
                        data.append(None)
                    else:
                        image = image.data
                        data.append(np.transpose(image, (2, 0, 1))
                            .reshape(-1).astype(np.uint8))
                        if image.shape[0] != CifarPath.IMAGE_SIZE or \
                                image.shape[1] != CifarPath.IMAGE_SIZE:
                            image_sizes[len(data) - 1] = \
                                (image.shape[0], image.shape[1])

            annotation_dict = {}
            annotation_dict['filenames'] = filenames
            if labels and (len(labels) == len(coarse_labels)):
                annotation_dict['fine_labels'] = labels
                annotation_dict['coarse_labels'] = coarse_labels
            else:
                annotation_dict['labels'] = labels
            annotation_dict['data'] = np.array(data, dtype=object)
            if image_sizes:
                size = (CifarPath.IMAGE_SIZE, CifarPath.IMAGE_SIZE)
                # 'image_sizes' isn't part of the standard format; it is
                # needed to support images of non-standard sizes
                annotation_dict['image_sizes'] = [image_sizes.get(p, size)
                    for p in range(len(data))]

            batch_label = None
            if subset_name.startswith(CifarPath.TRAIN_FILE_PREFIX):
                num = subset_name[len(CifarPath.TRAIN_FILE_PREFIX):]
                if cast(num, int) is not None:
                    batch_label = 'training batch %s of 5' % num
            elif subset_name == 'test':
                batch_label = 'testing batch 1 of 1'
            if batch_label:
                annotation_dict['batch_label'] = batch_label

            annotation_file = osp.join(self._save_dir, subset_name)

            if self._patch and subset_name in self._patch.updated_subsets \
                    and not annotation_dict['filenames']:
                # Remove subsets that became empty
                if osp.isfile(annotation_file):
                    os.remove(annotation_file)
                continue

            with open(annotation_file, 'wb') as labels_file:
                pickle.dump(annotation_dict, labels_file)
    @classmethod
    def patch(cls, dataset, patch, save_dir, **kwargs):
        for subset in patch.updated_subsets:
            conv = cls(dataset.get_subset(subset), save_dir=save_dir,
                **kwargs)
            conv._patch = patch
            conv.apply()

        for subset, status in patch.updated_subsets.items():
            if status != ItemStatus.removed:
                continue

            subset_file = osp.join(save_dir, subset)
            if osp.isfile(subset_file):
                os.remove(subset_file)
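
# A minimal conversion sketch, assuming the usual Datumaro Dataset API;
# it builds a tiny in-memory dataset instead of reading one from disk,
# and the output path is hypothetical.
if __name__ == '__main__':
    from datumaro.components.dataset import Dataset

    dataset = Dataset.from_iterable([
        DatasetItem(id='frog_01', subset='test',
            image=np.zeros((32, 32, 3), dtype=np.uint8),
            annotations=[Label(6)]),  # 6 -> 'frog' in Cifar10Label
    ], categories=Cifar10Label)
    dataset.export('path/to/output', 'cifar', save_images=True)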