# Source code for datumaro.plugins.sampler.random_sampler

# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

import argparse
from collections import defaultdict
from random import Random
from typing import List, Mapping, Optional, Tuple

from datumaro.components.annotation import AnnotationType
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import DatasetItem, IExtractor, Transform
from datumaro.util import cast


class RandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps no more than required number of items in the dataset.|n
    |n
    Notes:|n
    |s|s- Items are selected uniformly|n
    |s|s- Requesting a sample larger than the number of all images will|n
    |s|s|s|sreturn all images|n
    |n
    Example: select subset of 20 images randomly|n

    .. code-block::

    |s|s%(prog)s -k 20
    |n
    |n
    Example: select subset of 20 images, modify only 'train' subset|n

    .. code-block::

    |s|s%(prog)s -k 20 -s train
    """

    # NOTE(review): the class docstring appears to double as CLI help text
    # (it uses argparse's "%(prog)s" placeholder and "|n"/"|s" layout
    # markers) - keep those tokens intact and avoid adding stray "%" signs.

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base plugin parser with the sampler options."""
        parser = super().build_cmdline_parser(**kwargs)

        # (flags, keyword arguments) for every option, in registration order.
        option_specs = (
            (
                ("-k", "--count"),
                dict(type=int, required=True, help="Maximum number of items to sample"),
            ),
            (
                ("-s", "--subset"),
                dict(
                    default=None,
                    help="Limit changes to this subset (default: affect all dataset)",
                ),
            ),
            (
                ("--seed",),
                dict(type=int, help="Initial value for random number generator"),
            ),
        )
        for flags, options in option_specs:
            parser.add_argument(*flags, **options)
        return parser

    def __init__(
        self,
        extractor: IExtractor,
        count: int,
        *,
        subset: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        """
        Args:
            extractor: Source dataset.
            count: Maximum number of items to keep.
            subset: When set, only items of this subset are sampled; items of
                other subsets are passed through unchanged.
            seed: Seed for the random number generator.
        """
        super().__init__(extractor)
        self._count = count
        self._subset = subset
        self._seed = seed
        # Sorted positions of the kept items; drawn lazily on first iteration
        # and then reused so repeated iterations yield the same sample.
        self._indices = None

    def _ensure_indices(self) -> List[int]:
        """Draw (once) and return the sorted positions of the kept items."""
        if self._indices is not None:
            return self._indices

        # Positions are counted only among the items that are subject to
        # sampling: the whole dataset, or just the requested subset.
        if self._subset:
            population = len(self._extractor.get_subset(self._subset))
        else:
            population = len(self._extractor)

        rng = Random(self._seed)
        drawn = rng.sample(range(population), min(self._count, population))
        self._indices = sorted(drawn)
        return self._indices

    def __iter__(self):
        picks = iter(self._ensure_indices())
        target = next(picks, None)  # position of the next item to keep
        only_subset = self._subset

        # Nothing to keep and no pass-through subset: the result is empty.
        if target is None and not only_subset:
            return

        position = 0
        for item in self._extractor:
            # Items outside the sampled subset are forwarded untouched.
            if only_subset and item.subset != only_subset:
                yield item
                continue

            if position == target:
                yield item
                target = next(picks, None)
                # Without a subset filter, the remaining items are all dropped,
                # so the source does not need to be read any further.
                if target is None and not only_subset:
                    return

            position += 1
class LabelRandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps at least the required number of annotations of
    each class in the dataset for each subset separately.|n
    |n
    Consider using the "stats" command to get class distribution in
    the dataset.|n
    |n
    Notes:|n
    |s|s- Items can contain annotations of several selected classes|n
    |s|s|s|s(e.g. 3 bounding boxes per image). The number of annotations in the|n
    |s|s|s|sresulting dataset varies between max(class counts) and sum(class counts)|n
    |s|s- If the input dataset does not have enough class annotations, the result|n
    |s|s|s|swill contain only what is available|n
    |s|s- Items are selected uniformly|n
    |s|s- For reasons above, the resulting class distribution in the dataset may|n
    |s|s|s|snot be the same as requested|n
    |s|s- The resulting dataset will only keep annotations for|n
    |s|s|s|sclasses with specified count > 0|n
    |n
    Example: select at least 5 annotations of each class randomly|n

    .. code-block::

    |s|s%(prog)s -k 5
    |n
    |n
    Example: select at least 5 images with "cat" annotations and 3 "person"|n

    .. code-block::

    |s|s%(prog)s -l "cat:5" -l "person:3"
    """

    # NOTE(review): the class docstring appears to double as CLI help text
    # (it uses argparse's "%(prog)s" placeholder and "|n"/"|s" layout
    # markers) - keep those tokens intact and avoid adding stray "%" signs.

    @staticmethod
    def _parse_label_count(s: str) -> Tuple[str, int]:
        """
        Parse one "-l/--label" value of the form "label_name:count".

        The value is split at the LAST colon, so label names that themselves
        contain ":" are accepted. The count must be a non-negative integer.

        Raises:
            argparse.ArgumentError: if the separator is missing, the label
                name is empty, or the count is not a non-negative integer.
        """
        label, sep, count = s.rpartition(":")
        if not sep:
            raise argparse.ArgumentError(
                None, f"Invalid class count '{s}', expected 'label_name:count'"
            )

        count = cast(count, int, default=None)
        if not label:
            raise argparse.ArgumentError(None, "Class name cannot be empty")
        if count is None or count < 0:
            raise argparse.ArgumentError(None, f"Class '{label}' count is invalid")
        return label, count

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base plugin parser with the sampler options."""
        parser = super().build_cmdline_parser(**kwargs)
        # Not required: the class docstring shows "-l" used on its own, and
        # __init__ treats a missing count as 0 (keep only the "-l" classes).
        parser.add_argument(
            "-k",
            "--count",
            type=int,
            help="Minimum number of annotations of each class. "
            "If omitted, only the classes given with -l/--label are kept",
        )
        parser.add_argument(
            "-l",
            "--label",
            dest="label_counts",
            action="append",
            type=cls._parse_label_count,
            help="Minimum number of annotations of a specific class. "
            "Overrides the `-k/--count` setting for the class. "
            "The format is 'label_name:count' (repeatable)",
        )
        parser.add_argument("--seed", type=int, help="Initial value for random number generator")
        return parser

    def __init__(
        self,
        extractor: IExtractor,
        *,
        count: Optional[int] = None,
        label_counts: Optional[Mapping[str, int]] = None,
        seed: Optional[int] = None,
    ):
        """
        Args:
            extractor: Source dataset.
            count: Default minimum number of annotations per class, used for
                classes without an entry in `label_counts`. None means 0.
            label_counts: Per-class overrides keyed by label name. An iterable
                of (label_name, count) pairs is accepted too - that is what
                the "-l/--label" CLI option produces.
            seed: Seed for the random number generator.

        Raises:
            AssertionError: if neither `count` nor any per-class count is
                positive.
        """
        # Local import, kept as in the original module layout
        # (presumably to avoid an import cycle - verify before hoisting).
        from datumaro.plugins.transforms import ProjectLabels

        count = count or 0
        label_counts = dict(label_counts or {})
        assert count > 0 or any((c or 0) > 0 for c in label_counts.values()), (
            "Expected a positive 'count' or at least one positive class count"
        )

        # Only classes with a positive count are kept; negative or zero
        # counts exclude the class instead of breaking the sampling below.
        new_labels = {}
        for label in extractor.categories()[AnnotationType.label]:
            label_count = label_counts.get(label.name, count) or 0
            if label_count > 0:
                new_labels[label.name] = label_count

        # After projection, label indices follow the order of `new_labels`.
        self._label_counts = dict(enumerate(new_labels.values()))

        super().__init__(ProjectLabels(extractor, new_labels.keys()))

        self._seed = seed

        # Cached selection, so repeated iterations return the same sample.
        self._selected_items: Optional[List[DatasetItem]] = None

    def __iter__(self):
        if self._selected_items is not None:
            yield from self._selected_items
            return

        # Reservoir sampling (Algorithm R), run independently for each
        # (subset, label) stream: https://en.wikipedia.org/wiki/Reservoir_sampling
        def _make_bucket():  # label -> bucket
            return {label: [] for label in self._label_counts}

        buckets = defaultdict(_make_bucket)  # subset -> label -> bucket

        # Items seen so far in each (subset, label) stream. The replacement
        # draw must use this per-stream count, not the global item index,
        # otherwise items are not selected uniformly.
        seen = defaultdict(int)

        rng = Random(self._seed)
        for item in self._extractor:
            labels = {getattr(ann, "label", None) for ann in item.annotations}
            labels.discard(None)
            if not labels:
                continue

            subset_buckets = buckets[item.subset]
            # Sorted for a reproducible order of RNG draws.
            for label in sorted(labels):
                capacity = self._label_counts[label]
                bucket = subset_buckets[label]
                stream_key = (item.subset, label)
                seen[stream_key] += 1

                if len(bucket) < capacity:
                    bucket.append(item)
                else:
                    j = rng.randint(1, seen[stream_key])
                    if j <= capacity:
                        bucket[j - 1] = item

        # An item may be picked for several classes; keep it only once.
        selected_items = {}
        for subset_buckets in buckets.values():
            for label_bucket in subset_buckets.values():
                for item in label_bucket:
                    selected_items.setdefault((item.id, item.subset), item)

        self._selected_items = list(selected_items.values())
        yield from self._selected_items