# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import defaultdict
from random import Random
from typing import List, Mapping, Optional, Tuple
import argparse
from datumaro.components.annotation import AnnotationType
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import DatasetItem, IExtractor, Transform
from datumaro.util import cast
class RandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps no more than required number of items in the dataset.|n
    |n
    Notes:|n
    - Items are selected uniformly|n
    - Requesting a sample larger than the number of all images will
    return all images|n
    |n
    Example: select subset of 20 images randomly|n
    |s|s%(prog)s -k 20 |n
    Example: select subset of 20 images, modify only 'train' subset|n
    |s|s%(prog)s -k 20 -s train
    """

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        # Extend the parser produced by the base plugin with the options
        # understood by this sampler's constructor.
        cli = super().build_cmdline_parser(**kwargs)
        cli.add_argument('-k', '--count', type=int, required=True,
            help="Maximum number of items to sample")
        cli.add_argument('-s', '--subset', default=None,
            help="Limit changes to this subset (default: affect all dataset)")
        cli.add_argument('--seed', type=int,
            help="Initial value for random number generator")
        return cli

    def __init__(self, extractor: IExtractor, count: int, *,
            subset: Optional[str] = None, seed: Optional[int] = None):
        super().__init__(extractor)

        self._subset = subset
        self._count = count
        self._seed = seed

        # Sorted positions of the sampled items; drawn lazily on the first
        # iteration and then reused so every pass yields the same sample.
        self._indices = None

    def __iter__(self):
        """
        Yields the sampled items in their original order. When a subset is
        configured, items of all other subsets are passed through unchanged
        and only items of that subset are subject to sampling.
        """
        if self._indices is None:
            if self._subset:
                population = len(self._extractor.get_subset(self._subset))
            else:
                population = len(self._extractor)
            rng = Random(self._seed)
            # Clamping to the population size makes an oversized request
            # simply keep everything.
            self._indices = sorted(
                rng.sample(range(population), min(self._count, population)))

        restricted = bool(self._subset)
        wanted = iter(self._indices)
        target = next(wanted, None)
        if target is None and not restricted:
            # Nothing sampled and no other subsets to pass through.
            return

        position = 0 # index among the items eligible for sampling
        for item in self._extractor:
            if restricted and self._subset != item.subset:
                yield item # untouched subset
                continue

            if position == target:
                yield item
                target = next(wanted, None)
                if target is None and not restricted:
                    # All picks emitted; the remaining items would be
                    # dropped anyway, so stop reading the source early.
                    return
            position += 1
class LabelRandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps at least the required number of annotations of
    each class in the dataset for each subset separately.|n
    |n
    Consider using the "stats" command to get class distribution in
    the dataset.|n
    |n
    Notes:|n
    - Items can contain annotations of several selected classes
    (e.g. 3 bounding boxes per image). The number of annotations in the
    resulting dataset varies between max(class counts) and sum(class counts)|n
    - If the input dataset does not have enough class annotations, the result
    will contain only what is available|n
    - Items are selected uniformly|n
    - For reasons above, the resulting class distribution in the dataset may
    not be the same as requested|n
    - The resulting dataset will only keep annotations for
    classes with specified count > 0|n
    |n
    Example: select at least 5 annotations of each class randomly|n
    |s|s%(prog)s -k 5 |n
    Example: select at least 5 images with "cat" annotations and 3 "person"|n
    |s|s%(prog)s -l "cat:5" -l "person:3"
    """

    @staticmethod
    def _parse_label_count(s: str) -> Tuple[str, int]:
        """
        Parses a 'label_name:count' command line value into (label, count).

        The value is split on the LAST ':' so that label names which
        themselves contain ':' are accepted (the count never does).

        Raises argparse.ArgumentError if the separator is missing, the label
        is empty, or the count is not a non-negative integer.
        """
        label, sep, count = s.rpartition(':')
        if not sep:
            raise argparse.ArgumentError(None,
                f"Invalid class count '{s}', expected 'label_name:count'")
        if not label:
            raise argparse.ArgumentError(None, "Class name cannot be empty")

        count = cast(count, int, default=None)
        if count is None or count < 0:
            raise argparse.ArgumentError(None,
                f"Class '{label}' count is invalid")
        return label, count

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        parser = super().build_cmdline_parser(**kwargs)
        # Not required: per-class counts given with -l are sufficient on
        # their own (see the second example above). The constructor checks
        # that at least one non-zero count was supplied.
        parser.add_argument('-k', '--count', type=int,
            help="Minimum number of annotations of each class")
        parser.add_argument('-l', '--label', dest='label_counts',
            action='append', type=cls._parse_label_count,
            help="Minimum number of annotations of a specific class. "
                "Overrides the `-k/--count` setting for the class. "
                "The format is 'label_name:count' (repeatable)")
        parser.add_argument('--seed', type=int,
            help="Initial value for random number generator")
        return parser

    def __init__(self, extractor: IExtractor, *,
            count: Optional[int] = None,
            label_counts: Optional[Mapping[str, int]] = None,
            seed: Optional[int] = None):
        """
        Args:
            extractor: source dataset
            count: default requested count for every class (None means 0)
            label_counts: per-class counts overriding `count`; either a
                mapping or an iterable of (label_name, count) pairs, as
                produced by the `-l` CLI option
            seed: seed for the random number generator
        """
        from datumaro.plugins.transforms import ProjectLabels

        count = count or 0
        label_counts = dict(label_counts or {})
        assert count or any(label_counts.values()), \
            "Expected a non-zero default count or a non-zero class count"

        # Keep only the classes with a non-zero requested count, in the
        # order of the source label categories.
        new_labels = {}
        for label in extractor.categories()[AnnotationType.label]:
            label_count = label_counts.get(label.name, count)
            if label_count:
                new_labels[label.name] = label_count

        # Label index -> requested count. This relies on ProjectLabels
        # numbering the kept labels in the order they are passed below
        # (assumption carried over from the original code).
        self._label_counts = dict(enumerate(new_labels.values()))

        super().__init__(ProjectLabels(extractor, new_labels.keys()))

        self._seed = seed

        # Cached result, so repeated iterations yield the same sample
        self._selected_items: Optional[List[DatasetItem]] = None

    def __iter__(self):
        if self._selected_items is not None:
            yield from self._selected_items
            return

        # Uses the reservoir sampling algorithm (Algorithm R) independently
        # for each (subset, label) pair: each bucket keeps a uniform random
        # sample of up to the requested number of items carrying the label.
        # https://en.wikipedia.org/wiki/Reservoir_sampling
        def _make_bucket():
            # label -> bucket
            return { label: [] for label in self._label_counts }

        buckets = defaultdict(_make_bucket) # subset -> subset_buckets

        # (subset, label) -> number of items with this label seen so far in
        # this subset. Algorithm R must draw the replacement slot from the
        # size of this particular stream, not from the global item index.
        # Using the global index biases the sample towards early items and
        # fails with randint(1, 0) on the very first item.
        seen = defaultdict(int)

        rng = Random(self._seed)

        for item in self._extractor:
            labels = set(getattr(ann, 'label', None)
                for ann in item.annotations)
            labels.discard(None)
            if not labels:
                continue

            subset_buckets = buckets[item.subset]
            # Sorted so random numbers are consumed in a fixed order for a
            # given seed, independent of the annotation order.
            for label in sorted(labels):
                bucket = subset_buckets[label]
                capacity = self._label_counts[label]

                stream_key = (item.subset, label)
                seen[stream_key] += 1

                if len(bucket) < capacity:
                    bucket.append(item)
                else:
                    j = rng.randint(1, seen[stream_key])
                    if j <= capacity:
                        bucket[j - 1] = item

        # Merge the buckets; an item picked for several labels is kept once.
        selected_items = {}
        for subset_buckets in buckets.values():
            for label_bucket in subset_buckets.values():
                for item in label_bucket:
                    selected_items.setdefault((item.id, item.subset), item)

        self._selected_items = list(selected_items.values())
        yield from self._selected_items