# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
import argparse
from collections import defaultdict
from random import Random
from typing import List, Mapping, Optional, Tuple
from datumaro.components.annotation import AnnotationType
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import DatasetItem, IExtractor, Transform
from datumaro.util import cast
class RandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps no more than required number of items in the dataset.|n
    |n
    Notes:|n
    |s|s- Items are selected uniformly|n
    |s|s- Requesting a sample larger than the number of all images will|n
    |s|s|s|sreturn all images|n
    |s|s- When a subset is specified, items of other subsets are kept as is|n
    |n
    Example: select subset of 20 images randomly|n
    .. code-block::
    |s|s%(prog)s -k 20 |n
    |n
    Example: select subset of 20 images, modify only 'train' subset|n
    .. code-block::
    |s|s%(prog)s -k 20 -s train
    """

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument(
            "-k", "--count", type=int, required=True, help="Maximum number of items to sample"
        )
        parser.add_argument(
            "-s",
            "--subset",
            default=None,
            help="Limit changes to this subset (default: affect all dataset)",
        )
        parser.add_argument("--seed", type=int, help="Initial value for random number generator")
        return parser

    def __init__(
        self,
        extractor: IExtractor,
        count: int,
        *,
        subset: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        """
        Args:
            extractor: The source dataset
            count: Maximum number of items to keep
            subset: If set (non-empty), only items of this subset are sampled;
                items of other subsets are passed through unchanged
            seed: Seed for the random number generator
        """
        super().__init__(extractor)
        self._seed = seed
        self._count = count
        # Sorted positions (within the sampled population) of the kept items.
        # Computed lazily on the first iteration and reused afterwards, so
        # repeated iterations return the same sample.
        self._indices: Optional[List[int]] = None
        self._subset = subset

    def __iter__(self):
        if self._indices is None:
            rng = Random(self._seed)
            if self._subset:
                population = len(self._extractor.get_subset(self._subset))
            else:
                population = len(self._extractor)
            self._indices = sorted(rng.sample(range(population), min(self._count, population)))

        picks = iter(self._indices)
        next_pick = next(picks, None)
        if next_pick is None and not self._subset:
            # Nothing was selected and every item is subject to sampling
            return

        # Position of the current item among the sampled population
        pos = 0
        for item in self._extractor:
            if self._subset and item.subset != self._subset:
                # Items outside the sampled subset are kept untouched
                yield item
                continue

            if pos == next_pick:
                yield item
                next_pick = next(picks, None)
                if next_pick is None and not self._subset:
                    # All picks are emitted and no other items can be kept
                    return
            pos += 1


class LabelRandomSampler(Transform, CliPlugin):
    """
    Sampler that keeps at least the required number of annotations of
    each class in the dataset for each subset separately.|n
    |n
    Consider using the "stats" command to get class distribution in
    the dataset.|n
    |n
    Notes:|n
    |s|s- Items can contain annotations of several selected classes|n
    |s|s|s|s(e.g. 3 bounding boxes per image). The number of annotations in the|n
    |s|s|s|sresulting dataset varies between max(class counts) and sum(class counts)|n
    |s|s- If the input dataset does not have enough class annotations, the result|n
    |s|s|s|swill contain only what is available|n
    |s|s- Items are selected uniformly|n
    |s|s- For reasons above, the resulting class distribution in the dataset may|n
    |s|s|s|snot be the same as requested|n
    |s|s- The resulting dataset will only keep annotations for|n
    |s|s|s|sclasses with specified count > 0|n
    |n
    Example: select at least 5 annotations of each class randomly|n
    .. code-block::
    |s|s%(prog)s -k 5 |n
    |n
    Example: select at least 5 images with "cat" annotations and 3 "person"|n
    .. code-block::
    |s|s%(prog)s -l "cat:5" -l "person:3"
    """

    @staticmethod
    def _parse_label_count(s: str) -> Tuple[str, int]:
        """
        Parses a 'label_name:count' command line value.

        The count is separated at the last ':', so label names may contain ':'.

        Raises:
            argparse.ArgumentError: if the value has no ':', the label is empty
                or the count is not a non-negative integer
        """
        label, sep, count_str = s.rpartition(":")
        if not sep:
            raise argparse.ArgumentError(
                None, f"Invalid value '{s}', expected the 'label_name:count' format"
            )
        if not label:
            raise argparse.ArgumentError(None, "Class name cannot be empty")
        count = cast(count_str, int, default=None)
        if count is None or count < 0:
            raise argparse.ArgumentError(None, f"Class '{label}' count is invalid")
        return label, count

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        parser = super().build_cmdline_parser(**kwargs)
        # Not required: the '-l'-only usage shown in the class docstring must
        # work, and __init__ treats a missing count as 0
        parser.add_argument(
            "-k",
            "--count",
            type=int,
            help="Minimum number of annotations of each class "
            "(default: 0, i.e. keep only the classes given with -l/--label)",
        )
        parser.add_argument(
            "-l",
            "--label",
            dest="label_counts",
            action="append",
            type=cls._parse_label_count,
            help="Minimum number of annotations of a specific class. "
            "Overrides the `-k/--count` setting for the class. "
            "The format is 'label_name:count' (repeatable)",
        )
        parser.add_argument("--seed", type=int, help="Initial value for random number generator")
        return parser

    def __init__(
        self,
        extractor: IExtractor,
        *,
        count: Optional[int] = None,
        label_counts: Optional[Mapping[str, int]] = None,
        seed: Optional[int] = None,
    ):
        """
        Args:
            extractor: The source dataset
            count: Default minimum number of annotations for each class.
                None is treated as 0, i.e. the class is dropped unless it is
                listed in label_counts
            label_counts: Per-class overrides of 'count', keyed by label name.
                A sequence of (label_name, count) pairs is also accepted,
                as produced by the '-l/--label' command line option
            seed: Seed for the random number generator

        Raises:
            ValueError: if neither 'count' nor any of 'label_counts' is positive
        """
        # Function-level import, presumably to avoid an import cycle — TODO confirm
        from datumaro.plugins.transforms import ProjectLabels

        count = count or 0
        label_counts = dict(label_counts or {})
        if not (count or any(label_counts.values())):
            raise ValueError("At least one class must have a positive sample count")

        # label name -> requested count, in the source category order,
        # keeping only classes that are actually requested
        new_labels = {}
        for label in extractor.categories()[AnnotationType.label]:
            label_count = label_counts.get(label.name, count)
            if label_count:
                new_labels[label.name] = label_count

        # After projection, label indices follow the order of new_labels
        self._label_counts = dict(enumerate(new_labels.values()))
        super().__init__(ProjectLabels(extractor, new_labels.keys()))
        self._seed = seed

        # Cached result, so repeated iterations return the same sample
        self._selected_items: Optional[List[DatasetItem]] = None

    def __iter__(self):
        if self._selected_items is not None:
            yield from self._selected_items
            return

        # Reservoir sampling (Algorithm R) is run independently for every
        # (subset, label) pair: https://en.wikipedia.org/wiki/Reservoir_sampling
        def _make_buckets():
            # label -> reservoir of items
            return {label: [] for label in self._label_counts}

        buckets = defaultdict(_make_buckets)  # subset -> label -> reservoir
        seen = defaultdict(int)  # (subset, label) -> number of candidates seen

        rng = Random(self._seed)
        for item in self._extractor:
            labels = {getattr(ann, "label", None) for ann in item.annotations}
            labels.discard(None)

            subset_buckets = buckets[item.subset]
            # Sorted to make RNG consumption order explicit and reproducible
            for label in sorted(labels):
                reservoir = subset_buckets[label]
                capacity = self._label_counts[label]

                seen[(item.subset, label)] += 1
                if len(reservoir) < capacity:
                    reservoir.append(item)
                else:
                    # The k-th candidate of this reservoir replaces a random
                    # slot with probability capacity / k, which keeps the
                    # selection uniform over all candidates
                    j = rng.randint(1, seen[(item.subset, label)])
                    if j <= capacity:
                        reservoir[j - 1] = item

        # An item can be picked for several labels; keep each one only once
        selected_items = {}
        for subset_buckets in buckets.values():
            for reservoir in subset_buckets.values():
                for item in reservoir:
                    selected_items.setdefault((item.id, item.subset), item)

        self._selected_items = list(selected_items.values())
        yield from self._selected_items