Source code for datumaro.plugins.cifar_format
# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from collections import OrderedDict
import os
import os.path as osp
import pickle  # nosec - disable B403:import_pickle check - fixed

import numpy as np
import numpy.core.multiarray

from datumaro.components.annotation import (
    AnnotationType, Label, LabelCategories,
)
from datumaro.components.converter import Converter
from datumaro.components.dataset import ItemStatus
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.util import cast
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class RestrictedUnpickler(pickle.Unpickler):
    # Resolve only the whitelisted numpy globals; reject everything else so
    # that loading a pickled batch cannot execute arbitrary code.
    def find_class(self, module, name):
        if module == "numpy.core.multiarray" and \
                name in PickleLoader.safe_numpy:
            return getattr(numpy.core.multiarray, name)
        elif module == 'numpy' and name in PickleLoader.safe_numpy:
            return getattr(numpy, name)
        raise pickle.UnpicklingError("Global '%s.%s' is forbidden"
            % (module, name))
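

# PickleLoader is referenced above and below but is missing from this
# listing; this is a minimal reconstruction consistent with those uses
# (the exact upstream whitelist may differ).
class PickleLoader:
    # numpy globals that legitimately occur in pickled CIFAR batches
    safe_numpy = {
        'dtype',
        'ndarray',
        '_reconstruct',
    }

    @staticmethod
    def restricted_load(file):
        return RestrictedUnpickler(file).load()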


class CifarPath:
    META_10_FILE = 'batches.meta'
    META_100_FILE = 'meta'
    TRAIN_FILE_PREFIX = 'data_batch_'
    USELESS_FILE = 'file.txt~'
    IMAGE_SIZE = 32

Cifar10Label = ['airplane', 'automobile', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
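
# For reference, the standard "Python version" archives that these names
# correspond to look like this on disk:
#   CIFAR-10:  batches.meta, data_batch_1 ... data_batch_5, test_batch
#   CIFAR-100: meta, train, test, file.txt~ (the last file is skipped)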

# Support for the Python version of CIFAR-10/100

class CifarExtractor(SourceExtractor):
    def __init__(self, path, subset=None):
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]
        super().__init__(subset=subset)

        self._categories = self._load_categories(osp.dirname(path))
        self._items = list(self._load_items(path).values())
    def _load_categories(self, path):
        # Coarse (superclass) names, filled from the CIFAR-100 meta file.
        # Initialized here so _load_items can rely on the attribute existing.
        self._coarse_labels = []

        if has_meta_file(path):
            return { AnnotationType.label: LabelCategories.
                from_iterable(parse_meta_file(path).keys()) }

        label_cat = LabelCategories()

        meta_file = osp.join(path, CifarPath.META_10_FILE)
        if not osp.isfile(meta_file):
            meta_file = osp.join(path, CifarPath.META_100_FILE)
        if osp.isfile(meta_file):
            # CIFAR-10:
            # num_cases_per_batch: 1000
            # label_names: ['airplane', 'automobile', 'bird', 'cat', 'deer',
            #     'dog', 'frog', 'horse', 'ship', 'truck']
            # num_vis: 3072
            # CIFAR-100:
            # fine_label_names: ['apple', 'aquarium_fish', 'baby', ...]
            # coarse_label_names: ['aquatic_mammals', 'fish', 'flowers', ...]
            with open(meta_file, 'rb') as labels_file:
                data = PickleLoader.restricted_load(labels_file)
            labels = data.get('label_names')
            if labels is not None:
                for label in labels:
                    label_cat.add(label)
            else:
                labels = data.get('fine_label_names')
                self._coarse_labels = data.get('coarse_label_names', [])
                if labels is not None:
                    for label in labels:
                        label_cat.add(label)
        else:
            for label in Cifar10Label:
                label_cat.add(label)

        return { AnnotationType.label: label_cat }
    def _load_items(self, path):
        items = {}

        label_cat = self._categories[AnnotationType.label]

        # 'batch_label': 'training batch 1 of 5'
        # 'data': ndarray
        # 'filenames': list
        # CIFAR-10: 'labels': list
        # CIFAR-100: 'fine_labels': list
        #            'coarse_labels': list
        with open(path, 'rb') as anno_file:
            annotation_dict = PickleLoader.restricted_load(anno_file)

        labels = annotation_dict.get('labels', [])
        coarse_labels = annotation_dict.get('coarse_labels', [])
        if len(labels) == 0:
            labels = annotation_dict.get('fine_labels', [])
        filenames = annotation_dict.get('filenames', [])
        # Default to an empty list so the length checks below are safe
        # when the 'data' key is absent.
        images_data = annotation_dict.get('data', [])
        size = annotation_dict.get('image_sizes')

        if len(labels) != len(filenames):
            raise Exception("The sizes of the arrays 'filenames', " \
                "'labels' don't match.")
        if 0 < len(images_data) and len(images_data) != len(filenames):
            raise Exception("The sizes of the arrays 'data', " \
                "'filenames', 'labels' don't match.")

        for i, (filename, label) in enumerate(zip(filenames, labels)):
            item_id = osp.splitext(filename)[0]
            annotations = []
            if label is not None:
                annotations.append(Label(label))
                if 0 < len(coarse_labels) and coarse_labels[i] is not None and \
                        label_cat[label].parent == '':
                    label_cat[label].parent = \
                        self._coarse_labels[coarse_labels[i]]

            image = None
            if 0 < len(images_data):
                image = images_data[i]
            if size is not None and image is not None:
                image = image.astype(np.uint8) \
                    .reshape(3, size[i][0], size[i][1])
                image = np.transpose(image, (1, 2, 0))
            elif image is not None:
                image = image.astype(np.uint8) \
                    .reshape(3, CifarPath.IMAGE_SIZE, CifarPath.IMAGE_SIZE)
                image = np.transpose(image, (1, 2, 0))

            items[item_id] = DatasetItem(id=item_id, subset=self._subset,
                image=image, annotations=annotations)

        return items
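
# Usage sketch (hypothetical path; the file basename, e.g. 'data_batch_1',
# becomes the subset name):
#
#   extractor = CifarExtractor('./cifar-10-batches-py/data_batch_1')
#   for item in extractor:
#       print(item.id, item.annotations)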


class CifarImporter(Importer):
    @classmethod
    def find_sources(cls, path):
        # Subset files have no extension in this format, and the meta and
        # service files must not be picked up as subsets.
        return cls._find_sources_recursive(path, '', 'cifar',
            file_filter=lambda p: \
                not osp.splitext(osp.basename(p))[1] and \
                osp.basename(p) not in {
                    CifarPath.META_10_FILE, CifarPath.META_100_FILE,
                    CifarPath.USELESS_FILE
                })
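
# A directory in this layout is normally loaded through the generic entry
# point rather than by instantiating the importer directly (a sketch,
# assuming the 'cifar' format name this plugin registers):
#
#   from datumaro.components.dataset import Dataset
#   dataset = Dataset.import_from('./cifar-10-batches-py', 'cifar')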


class CifarConverter(Converter):
    DEFAULT_IMAGE_EXT = '.png'

    def apply(self):
        os.makedirs(self._save_dir, exist_ok=True)

        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        label_categories = self._extractor.categories()[AnnotationType.label]
        label_names = []
        coarse_label_names = []
        for label in label_categories:
            label_names.append(label.name)
            if label.parent and (label.parent not in coarse_label_names):
                coarse_label_names.append(label.parent)
        coarse_label_names.sort()

        if coarse_label_names:
            labels_dict = { 'fine_label_names': label_names,
                'coarse_label_names': coarse_label_names }
            coarse_label_names = OrderedDict((name, i)
                for i, name in enumerate(coarse_label_names))
            meta_file = osp.join(self._save_dir, CifarPath.META_100_FILE)
        else:
            labels_dict = { 'label_names': label_names }
            meta_file = osp.join(self._save_dir, CifarPath.META_10_FILE)
        with open(meta_file, 'wb') as f:
            pickle.dump(labels_dict, f)
        for subset_name, subset in self._extractor.subsets().items():
            filenames = []
            labels = []
            coarse_labels = []
            data = []
            image_sizes = {}
            for item in subset:
                filenames.append(self._make_image_filename(item))

                anns = [a for a in item.annotations
                    if a.type == AnnotationType.label]
                if anns:
                    labels.append(anns[0].label)
                    if coarse_label_names:
                        superclass = label_categories[anns[0].label].parent
                        coarse_labels.append(coarse_label_names[superclass])
                else:
                    labels.append(None)
                    coarse_labels.append(None)

                if self._save_images and item.has_image:
                    image = item.image
                    if not image.has_data:
                        data.append(None)
                    else:
                        image = image.data
                        data.append(np.transpose(image, (2, 0, 1)) \
                            .reshape(-1).astype(np.uint8))
                        if image.shape[0] != CifarPath.IMAGE_SIZE or \
                                image.shape[1] != CifarPath.IMAGE_SIZE:
                            image_sizes[len(data) - 1] = \
                                (image.shape[0], image.shape[1])

            annotation_dict = {}
            annotation_dict['filenames'] = filenames
            if labels and (len(labels) == len(coarse_labels)):
                annotation_dict['fine_labels'] = labels
                annotation_dict['coarse_labels'] = coarse_labels
            else:
                annotation_dict['labels'] = labels
            annotation_dict['data'] = np.array(data, dtype=object)
            if image_sizes:
                size = (CifarPath.IMAGE_SIZE, CifarPath.IMAGE_SIZE)
                # 'image_sizes' isn't part of the standard format; it is
                # needed to support differently sized images
                annotation_dict['image_sizes'] = [image_sizes.get(p, size)
                    for p in range(len(data))]

            batch_label = None
            if subset_name.startswith(CifarPath.TRAIN_FILE_PREFIX):
                num = subset_name[len(CifarPath.TRAIN_FILE_PREFIX):]
                if cast(num, int) is not None:
                    batch_label = 'training batch %s of 5' % num
            elif subset_name == 'test':
                batch_label = 'testing batch 1 of 1'
            if batch_label:
                annotation_dict['batch_label'] = batch_label

            annotation_file = osp.join(self._save_dir, subset_name)

            if self._patch and subset_name in self._patch.updated_subsets \
                    and not annotation_dict['filenames']:
                if osp.isfile(annotation_file):
                    # Remove subsets that became empty
                    os.remove(annotation_file)
                continue

            with open(annotation_file, 'wb') as labels_file:
                pickle.dump(annotation_dict, labels_file)
    @classmethod
    def patch(cls, dataset, patch, save_dir, **kwargs):
        for subset in patch.updated_subsets:
            conv = cls(dataset.get_subset(subset), save_dir=save_dir, **kwargs)
            conv._patch = patch
            conv.apply()

        for subset, status in patch.updated_subsets.items():
            if status != ItemStatus.removed:
                continue
            subset_file = osp.join(save_dir, subset)
            if osp.isfile(subset_file):
                os.remove(subset_file)
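
# Export sketch (hypothetical output directory; 'cifar' is the format name
# under which this converter is registered):
#
#   dataset.export('./out-cifar', 'cifar', save_images=True)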