# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import defaultdict
from random import Random
from typing import List, Mapping, Optional, Tuple
import argparse
from datumaro.components.annotation import AnnotationType
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import DatasetItem, IExtractor, Transform
from datumaro.util import cast
class RandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps no more than required number of items in the dataset.|n
    |n
    Notes:|n
    - Items are selected uniformly|n
    - Requesting a sample larger than the number of all images will
    return all images|n
    |n
    Example: select subset of 20 images randomly|n
    |s|s%(prog)s -k 20 |n
    Example: select subset of 20 images, modify only 'train' subset|n
    |s|s%(prog)s -k 20 -s train
    """

    # NOTE(review): the class docstring above appears to double as the
    # command-line help text (see the `|n` / `|s` markers and `%(prog)s`),
    # so it is kept user-facing; implementation notes live in comments.

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """
        Extend the parser built by the parent class with the sampler options:
        the item count, an optional subset restriction and an optional seed.
        """
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('-k', '--count', type=int, required=True,
            help="Maximum number of items to sample")
        parser.add_argument('-s', '--subset', default=None,
            help="Limit changes to this subset (default: affect all dataset)")
        parser.add_argument('--seed', type=int,
            help="Initial value for random number generator")
        return parser

    def __init__(self, extractor: IExtractor, count: int, *,
            subset: Optional[str] = None, seed: Optional[int] = None):
        """
        Parameters:
            extractor: the dataset to sample from
            count: the maximum number of items to keep
            subset: when set, only this subset is sampled; items of all the
                other subsets are passed through unchanged
            seed: seed for the random number generator
        """
        super().__init__(extractor)

        self._seed = seed
        self._count = count
        self._subset = subset

        # Sorted positions of the picked items. Computed on the first
        # iteration and reused afterwards, so repeated iterations yield
        # the same sample.
        self._indices: Optional[List[int]] = None

    def __iter__(self):
        """
        Yields the sampled items in the source iteration order.
        """
        if self._indices is None:
            rng = Random(self._seed)

            if self._subset:
                total = len(self._extractor.get_subset(self._subset))
            else:
                total = len(self._extractor)

            # Never request more elements than available
            self._indices = sorted(
                rng.sample(range(total), min(self._count, total)))

        picks = iter(self._indices)
        next_pick = next(picks, None)
        if next_pick is None and not self._subset:
            # Nothing to select, and no other subsets to pass through
            return

        # Position among the items being sampled. With a subset restriction,
        # only the items of that subset are counted.
        position = 0
        for item in self._extractor:
            if self._subset and self._subset != item.subset:
                # Items outside of the target subset are kept as is
                yield item
                continue

            if position == next_pick:
                yield item

                next_pick = next(picks, None)
                if next_pick is None and not self._subset:
                    # All the picks are done, the remaining items are dropped
                    return

            # Once the picks are exhausted, `next_pick` is None and never
            # matches again, while other subsets still pass through
            position += 1
class LabelRandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps at least the required number of annotations of
    each class in the dataset for each subset separately.|n
    |n
    Consider using the "stats" command to get class distribution in
    the dataset.|n
    |n
    Notes:|n
    - Items can contain annotations of several selected classes
    (e.g. 3 bounding boxes per image). The number of annotations in the
    resulting dataset varies between max(class counts) and sum(class counts)|n
    - If the input dataset does not have enough class annotations, the result
    will contain only what is available|n
    - Items are selected uniformly|n
    - For reasons above, the resulting class distribution in the dataset may
    not be the same as requested|n
    - The resulting dataset will only keep annotations for
    classes with specified count > 0|n
    |n
    Example: select at least 5 annotations of each class randomly|n
    |s|s%(prog)s -k 5 |n
    Example: select at least 5 images with "cat" annotations and 3 "person"|n
    |s|s%(prog)s -l "cat:5" -l "person:3"
    """

    # NOTE(review): the class docstring above appears to double as the
    # command-line help text (see the `|n` / `|s` markers and `%(prog)s`),
    # so it is kept user-facing; implementation notes live in comments.

    @staticmethod
    def _parse_label_count(s: str) -> Tuple[str, int]:
        """
        Parses a 'label_name:count' command-line value.

        The string is split on the first ':'. A value without ':' fails
        the unpacking below with ValueError.

        Returns: (label name, count)

        Raises:
            argparse.ArgumentError: if the label name is empty, or if
                the count is negative or could not be converted to int
        """
        label, count = s.split(':', maxsplit=1)

        # `cast` presumably returns the default when the conversion fails,
        # which the None check below relies on - TODO confirm
        count = cast(count, int, default=None)

        if not label:
            raise argparse.ArgumentError(None, "Class name cannot be empty")
        if count is None or count < 0:
            raise argparse.ArgumentError(None,
                f"Class '{label}' count is invalid")

        return label, count

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """
        Extend the parser built by the parent class with the sampler options:
        the default per-class count, per-class overrides and an optional seed.
        """
        parser = super().build_cmdline_parser(**kwargs)
        # Not a required argument: the per-class counts can be given with
        # `-l` alone, as in the usage example in the class description.
        # The constructor checks that at least one count is provided.
        parser.add_argument('-k', '--count', type=int,
            help="Minimum number of annotations of each class")
        parser.add_argument('-l', '--label', dest='label_counts',
            action='append', type=cls._parse_label_count,
            help="Minimum number of annotations of a specific class. "
                "Overrides the `-k/--count` setting for the class. "
                "The format is 'label_name:count' (repeatable)")
        parser.add_argument('--seed', type=int,
            help="Initial value for random number generator")
        return parser

    def __init__(self, extractor: IExtractor, *,
            count: Optional[int] = None,
            label_counts: Optional[Mapping[str, int]] = None,
            seed: Optional[int] = None):
        """
        Parameters:
            extractor: the dataset to sample from
            count: the default minimum number of annotations per class
            label_counts: per-class minimums, override 'count'. Anything
                accepted by dict() works, including the list of
                (name, count) pairs produced by the command-line parser
            seed: seed for the random number generator

        At least one of 'count' and 'label_counts' must provide
        a positive value.
        """
        from datumaro.plugins.transforms import ProjectLabels

        count = count or 0
        label_counts = dict(label_counts or {})
        assert count or any(label_counts.values()), \
            "Either 'count' or a positive label count must be specified"

        # Keep only the classes with a positive requested count,
        # in the order of the source label categories
        new_labels = {}
        for label in extractor.categories()[AnnotationType.label]:
            label_count = label_counts.get(label.name, count)
            if label_count:
                new_labels[label.name] = label_count

        # Label index in the projected dataset -> requested count.
        # Assumes ProjectLabels numbers the labels in the order
        # they are passed - TODO confirm
        self._label_counts = dict(enumerate(new_labels.values()))

        super().__init__(ProjectLabels(extractor, new_labels.keys()))

        self._seed = seed

        # Cached result, for repeated calls
        self._selected_items: Optional[List[DatasetItem]] = None

    def __iter__(self):
        """
        Yields the selected items. The selection is computed on the first
        complete iteration and reused afterwards.
        """
        if self._selected_items is not None:
            yield from self._selected_items
            return

        # Uses the reservoir sampling algorithm for each class
        # https://en.wikipedia.org/wiki/Reservoir_sampling

        def _make_bucket():
            # label -> bucket
            return { label: [] for label in self._label_counts }

        buckets = defaultdict(_make_bucket) # subset -> subset_buckets

        # (subset, label) -> number of items with this label seen so far.
        # The replacement probability must depend on this number and not on
        # the position of the item in the whole dataset, otherwise
        # the selection is not uniform.
        seen = defaultdict(int)

        rng = Random(self._seed)

        for item in self._extractor:
            labels = { getattr(ann, 'label', None)
                for ann in item.annotations }
            labels.discard(None)

            subset_buckets = buckets[item.subset]

            # Sorted, so that the random numbers are consumed in a fixed order
            for label in sorted(labels):
                key = (item.subset, label)
                seen[key] += 1

                bucket = subset_buckets[label]
                quota = self._label_counts[label]
                if len(bucket) < quota:
                    bucket.append(item)
                else:
                    # The bucket is full, so this is at least
                    # the (quota + 1)-th item: keep it with
                    # the probability quota / seen
                    j = rng.randint(1, seen[key])
                    if j <= quota:
                        bucket[j - 1] = item

        # An item can get into several buckets, report it only once
        selected_items = {}
        for subset_buckets in buckets.values():
            for label_bucket in subset_buckets.values():
                for item in label_bucket:
                    selected_items.setdefault((item.id, item.subset), item)

        self._selected_items = list(selected_items.values())
        yield from self._selected_items