# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import hashlib
import logging as log
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from unittest import TestCase
import attr
import cv2
import numpy as np
from attr import attrib, attrs
from datumaro.components.annotation import (
    Annotation,
    AnnotationType,
    Bbox,
    Label,
    LabelCategories,
    MaskCategories,
    PointsCategories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import Dataset, DatasetItemStorage, IDataset
from datumaro.components.errors import (
    AnnotationsTooCloseError,
    ConflictingCategoriesError,
    DatasetMergeError,
    FailedAttrVotingError,
    FailedLabelVotingError,
    MediaTypeError,
    MismatchingAttributesError,
    MismatchingImageInfoError,
    MismatchingMediaError,
    MismatchingMediaPathError,
    NoMatchingAnnError,
    NoMatchingItemError,
    VideoMergeError,
    WrongGroupError,
)
from datumaro.components.extractor import CategoriesInfo, DatasetItem
from datumaro.components.media import Image, MediaElement, MultiframeImage, PointCloud, Video
from datumaro.util import filter_dict, find
from datumaro.util.annotation_util import (
    OKS,
    approximate_line,
    bbox_iou,
    find_instances,
    max_bbox,
    mean_bbox,
    segment_iou,
)
from datumaro.util.attrs_util import default_if_none, ensure_cls
def match_annotations_equal(a, b):
    """
    Pairs up elements of ``a`` and ``b`` that compare equal.

    Every element of ``a`` is paired with the first still-unmatched element
    of ``b`` it is equal to; each element of ``b`` is used at most once.

    Parameters:
        a: iterable of annotations (any finite iterable is accepted)
        b: iterable of annotations (any finite iterable is accepted)

    Returns:
        A ``(matches, a_unmatched, b_unmatched)`` tuple where ``matches`` is
        a list of ``(a_ann, b_ann)`` pairs and the other two are lists with
        the leftover elements, in their original order.
    """
    # Work on private list copies: the inputs are never modified and
    # non-list inputs (tuples, generators) are supported too.
    a = list(a)
    matches = []
    a_unmatched = list(a)
    b_unmatched = list(b)

    for a_ann in a:
        for b_ann in b_unmatched:
            if a_ann != b_ann:
                continue

            matches.append((a_ann, b_ann))
            a_unmatched.remove(a_ann)
            # Safe although we are iterating b_unmatched: we leave the loop
            # immediately after the removal.
            b_unmatched.remove(b_ann)
            break

    return matches, a_unmatched, b_unmatched
def merge_annotations_equal(a, b):
    """
    Builds the union of two annotation lists, sharing equal annotations.

    Returns a list made of: the ``a``-side element of every matched pair,
    then the elements of ``a`` without a match, then the elements of ``b``
    without a match (see ``match_annotations_equal``).
    """
    matches, a_unmatched, b_unmatched = match_annotations_equal(a, b)

    merged = [ann_a for ann_a, _ in matches]
    merged.extend(a_unmatched)
    merged.extend(b_unmatched)
    return merged
def merge_categories(sources):
    """
    Merges category mappings of several sources, requiring them to agree.

    Parameters:
        sources: iterable of mappings ``category type -> categories``

    Returns:
        A dict ``category type -> categories``. For each type, the first
        occurrence is used; an empty entry is replaced by the first
        non-empty one found later.

    Raises:
        ConflictingCategoriesError: two non-empty, unequal entries are
            found for the same category type.
    """
    categories = {}

    for source_idx, source in enumerate(sources):
        for cat_type, source_cat in source.items():
            existing_cat = categories.setdefault(cat_type, source_cat)
            if existing_cat != source_cat and len(source_cat) != 0:
                if len(existing_cat) == 0:
                    # An empty placeholder never conflicts, just replace it
                    categories[cat_type] = source_cat
                else:
                    raise ConflictingCategoriesError(
                        "Merging of datasets with different categories is "
                        "only allowed in 'merge' command.",
                        sources=list(range(source_idx)),
                    )

    return categories
class MergingStrategy(CliPlugin):
    """
    Base class for dataset merging strategies.

    A strategy instance is a callable that takes the sources to merge;
    subclasses must implement ``__call__``.
    """

    @classmethod
    def merge(cls, sources, **options):
        """Creates a strategy configured with ``options`` and applies it to ``sources``."""
        instance = cls(**options)
        return instance(sources)

    def __init__(self, **options):
        super().__init__(**options)
        # Written through __dict__ directly, bypassing attribute assignment
        self.__dict__["_sources"] = None

    def __call__(self, sources):
        """Merges ``sources``; has to be overridden in subclasses."""
        raise NotImplementedError()
class ExactMerge:
    """
    Merges several datasets using the "simple" algorithm:
        - items are matched by (id, subset) pairs
        - matching items share the media info available:
            - nothing + nothing = nothing
            - nothing + something = something
            - something A + something B = conflict
        - annotations are matched by value and shared
        - in case of conflicts, throws an error
    """
[docs]    @classmethod
    def merge(cls, *sources: IDataset) -> DatasetItemStorage:
        items = DatasetItemStorage()
        for source_idx, source in enumerate(sources):
            for item in source:
                existing_item = items.get(item.id, item.subset)
                if existing_item is not None:
                    try:
                        item = cls._merge_items(existing_item, item)
                    except DatasetMergeError as e:
                        e.sources = set(range(source_idx))
                        raise e
                items.put(item)
        return items 
    @classmethod
    def _merge_items(cls, existing_item: DatasetItem, current_item: DatasetItem) -> DatasetItem:
        """
        Combines two items sharing the same (id, subset) key.

        Media, attributes and annotations are merged one after another;
        the result is produced by ``existing_item.wrap(...)``.
        """
        item_key = (existing_item.id, existing_item.subset)

        merged_media = cls._merge_media(existing_item, current_item)
        merged_attributes = cls._merge_attrs(
            existing_item.attributes, current_item.attributes, item_id=item_key
        )
        merged_annotations = cls._merge_anno(existing_item.annotations, current_item.annotations)

        return existing_item.wrap(
            media=merged_media,
            attributes=merged_attributes,
            annotations=merged_annotations,
        )
    @staticmethod
    def _merge_attrs(a: Dict[str, Any], b: Dict[str, Any], item_id: Tuple[str, str]) -> Dict:
        """
        Builds the union of two attribute dicts.

        An attribute present on one side only is taken as is. An attribute
        present on both sides must have equal values.

        Raises:
            MismatchingAttributesError: both dicts have the attribute, but
                the values are different.
        """
        merged = {}

        for name in a.keys() | b.keys():
            a_val = a.get(name, None)
            b_val = b.get(name, None)

            if name not in a:
                merged[name] = b_val
                continue

            if name in b and a_val != b_val:
                raise MismatchingAttributesError(item_id, name, a_val, b_val)

            # either only "a" has the attribute, or both sides agree
            merged[name] = a_val

        return merged
    @classmethod
    def _merge_media(
        cls, item_a: DatasetItem, item_b: DatasetItem
    ) -> Union[Image, PointCloud, Video]:
        """
        Picks the media merging routine by the media kinds of both items.

        A kind is applicable when each item either has no media or has media
        of that kind. The kinds are tried in this order: Image, PointCloud,
        Video, MultiframeImage and, last, the generic MediaElement.

        Raises:
            MismatchingMediaPathError: two generic media elements have
                different non-empty paths.
            MismatchingMediaError: no kind fits both items.
        """
        media_a = item_a.media
        media_b = item_b.media

        def _absent_or(kind):
            # True when every item has either no media or media of this kind
            return (not media_a or isinstance(media_a, kind)) and (
                not media_b or isinstance(media_b, kind)
            )

        if _absent_or(Image):
            return cls._merge_images(item_a, item_b)
        if _absent_or(PointCloud):
            return cls._merge_point_clouds(item_a, item_b)
        if _absent_or(Video):
            return cls._merge_videos(item_a, item_b)
        if _absent_or(MultiframeImage):
            return cls._merge_multiframe_images(item_a, item_b)
        if not _absent_or(MediaElement):
            raise MismatchingMediaError((item_a.id, item_a.subset), media_a, media_b)

        # Generic media elements: only the paths can be compared
        a_is_media = isinstance(media_a, MediaElement)
        b_is_media = isinstance(media_b, MediaElement)
        if a_is_media and b_is_media:
            if media_a.path and media_b.path and media_a.path != media_b.path:
                raise MismatchingMediaPathError(
                    (item_a.id, item_a.subset), media_a.path, media_b.path
                )
            return media_a if media_a.path else media_b

        return media_a if a_is_media else media_b
    @staticmethod
    def _merge_images(item_a: DatasetItem, item_b: DatasetItem) -> Image:
        """
        Merges the image media of two matching items.

        If only one of the items has an Image, that media is returned
        (item_b.media is returned as is when item_a has no Image).

        If both have an Image, one of the two objects is selected by
        the first matching rule: a has data, b has data, a has a path,
        b has a path, a has a size, b has a size. When the selected image
        lacks data or a size, its ``_size`` field is overwritten in place with
        the first non-empty ``_size`` of item_a / item_b.

        Raises:
            MismatchingMediaPathError: both images have different non-empty
                paths and the same ``has_data`` value.
            MismatchingImageInfoError: both images have a size and
                the sizes are different.
        """
        media = None
        if isinstance(item_a.media, Image) and isinstance(item_b.media, Image):
            if (
                item_a.media.path
                and item_b.media.path
                and item_a.media.path != item_b.media.path
                and item_a.media.has_data is item_b.media.has_data
            ):
                # We use has_data as a replacement for path existence check
                # - If only one image has data, we'll use it. The other
                #   one is just a path metainfo, which is not significant
                #   in this case.
                # - If both images have data or both don't, we need
                #   to compare paths.
                #
                # Different paths can actually point to the same file,
                # but it's not the case we'd like to allow here to be
                # a "simple" merging strategy used for extractor joining
                raise MismatchingMediaPathError(
                    (item_a.id, item_a.subset), item_a.media.path, item_b.media.path
                )
            if (
                item_a.media.has_size
                and item_b.media.has_size
                and item_a.media.size != item_b.media.size
            ):
                raise MismatchingImageInfoError(
                    (item_a.id, item_a.subset), item_a.media.size, item_b.media.size
                )
            # Avoid direct comparison here for better performance
            # If there are 2 "data-only" images, they won't be compared and
            # we just use the first one
            if item_a.media.has_data:
                media = item_a.media
            elif item_b.media.has_data:
                media = item_b.media
            elif item_a.media.path:
                media = item_a.media
            elif item_b.media.path:
                media = item_b.media
            elif item_a.media.has_size:
                media = item_a.media
            elif item_b.media.has_size:
                media = item_b.media
            else:
                # Neither image has data, a path or a size
                assert False, "Unknown image field combination"
            # Carry the known size over to the selected image.
            # Note that this modifies the selected media object in place.
            if not media.has_data or not media.has_size:
                if item_a.media._size:
                    media._size = item_a.media._size
                elif item_b.media._size:
                    media._size = item_b.media._size
        elif isinstance(item_a.media, Image):
            media = item_a.media
        else:
            media = item_b.media
        return media
    @staticmethod
    def _merge_point_clouds(item_a: DatasetItem, item_b: DatasetItem) -> PointCloud:
        """
        Merges the point cloud media of two matching items.

        If item_a has no PointCloud, item_b.media is returned as is; if only
        item_a has one, it is returned. Otherwise, item_a's cloud is kept if
        it has a path or extra images (item_b's cloud is kept in the other
        case), and the extra images of the other cloud that are not yet
        there are appended to the kept cloud's ``extra_images`` in place.

        Raises:
            MismatchingMediaPathError: both clouds have different non-empty
                paths.
        """
        cloud_a = item_a.media
        cloud_b = item_b.media

        if not isinstance(cloud_a, PointCloud):
            return cloud_b
        if not isinstance(cloud_b, PointCloud):
            return cloud_a

        if cloud_a.path and cloud_b.path and cloud_a.path != cloud_b.path:
            raise MismatchingMediaPathError(
                (item_a.id, item_a.subset), cloud_a.path, cloud_b.path
            )

        if cloud_a.path or cloud_a.extra_images:
            kept, other = cloud_a, cloud_b
        else:
            kept, other = cloud_b, cloud_a

        if other.extra_images:
            for image in other.extra_images:
                if image not in kept.extra_images:
                    kept.extra_images.append(image)

        return kept
    @staticmethod
    def _merge_videos(item_a: DatasetItem, item_b: DatasetItem) -> Video:
        """
        Merges the video media of two matching items.

        If only one of the items has a Video, that media is returned
        (item_b.media is returned as is when item_a has no Video).
        Two videos can only be merged when they have the same path, start
        frame, end frame and step; item_a's video is returned then.

        Raises:
            VideoMergeError: both items have videos and the videos differ in
                the path or in the frame range parameters.
        """
        video_a = item_a.media
        video_b = item_b.media

        if isinstance(video_a, Video) and isinstance(video_b, Video):
            # The values are compared by equality, not identity: equal
            # strings or numbers are not guaranteed to be the same object.
            if (
                video_a.path != video_b.path
                or video_a._start_frame != video_b._start_frame
                or video_a._end_frame != video_b._end_frame
                or video_a._step != video_b._step
            ):
                raise VideoMergeError(item_a.id)
            return video_a

        if isinstance(video_a, Video):
            return video_a
        return video_b
    @staticmethod
    def _merge_multiframe_images(item_a: DatasetItem, item_b: DatasetItem) -> MultiframeImage:
        """
        Merges the multiframe image media of two matching items.

        If item_a has no MultiframeImage, item_b.media is returned as is; if
        only item_a has one, it is returned. Otherwise, item_a's media is
        kept if it has a path or data (item_b's media is kept in the other
        case), and the entries of the other media's ``data`` that are not yet
        there are appended to the kept media's ``data`` in place.

        Raises:
            MismatchingMediaPathError: both media have different non-empty
                paths.
        """
        frames_a = item_a.media
        frames_b = item_b.media

        if not isinstance(frames_a, MultiframeImage):
            return frames_b
        if not isinstance(frames_b, MultiframeImage):
            return frames_a

        if frames_a.path and frames_b.path and frames_a.path != frames_b.path:
            raise MismatchingMediaPathError(
                (item_a.id, item_a.subset), frames_a.path, frames_b.path
            )

        if frames_a.path or frames_a.data:
            kept, other = frames_a, frames_b
        else:
            kept, other = frames_b, frames_a

        if other.data:
            for image in other.data:
                if image not in kept.data:
                    kept.data.append(image)

        return kept
    @staticmethod
    def _merge_anno(a: Iterable[Annotation], b: Iterable[Annotation]) -> List[Annotation]:
        """Merges two annotation collections; equal annotations are shared."""
        merged = merge_annotations_equal(a, b)
        return merged
[docs]    @staticmethod
    def merge_categories(sources: Iterable[IDataset]) -> CategoriesInfo:
        return merge_categories(sources) 
 
@attrs
class IntersectMerge(MergingStrategy):
[docs]    @attrs(repr_ns="IntersectMerge", kw_only=True)
    class Conf:
        pairwise_dist = attrib(converter=float, default=0.5)
        sigma = attrib(converter=list, factory=list)
        output_conf_thresh = attrib(converter=float, default=0)
        quorum = attrib(converter=int, default=0)
        ignored_attributes = attrib(converter=set, factory=set)
        def _groups_converter(value):
            result = []
            for group in value:
                rg = set()
                for label in group:
                    optional = label.endswith("?")
                    name = label if not optional else label[:-1]
                    rg.add((name, optional))
                result.append(rg)
            return result
        groups = attrib(converter=_groups_converter, factory=list)
        close_distance = attrib(converter=float, default=0.75) 
    # Strategy settings, a Conf instance by default.
    # NOTE(review): ensure_cls(Conf) presumably also accepts a dict of Conf
    # fields and converts it - confirm in datumaro.util.attrs_util
    conf = attrib(converter=ensure_cls(Conf), factory=Conf)
    # Error trackers:
    # Problems found during merging are collected here (see add_item_error)
    # instead of being raised. Not an __init__ parameter.
    errors = attrib(factory=list, init=False)
[docs]    def add_item_error(self, error, *args, **kwargs):
        self.errors.append(error(self._item_id, *args, **kwargs)) 
    # Indexes:
    # All the fields below are working state with no __init__ parameters
    # (init=False); they are filled during the merge.
    _dataset_map = attrib(init=False)  # id(dataset) -> (dataset, index)
    _item_map = attrib(init=False)  # id(item) -> (item, id(dataset))
    _ann_map = attrib(init=False)  # id(ann) -> (ann, id(item))
    _item_id = attrib(init=False)  # (id, subset) of the item being merged
    _item = attrib(init=False)  # the first of the source items being merged
    # Misc.
    _categories = attrib(init=False)  # merged categories
    def __call__(self, datasets):
        """
        Merges the datasets into a new Dataset.

        Categories are merged first, then the items are matched across the
        datasets by (id, subset) and merged one by one. An item missing in
        some of the datasets is reported with a NoMatchingItemError entry in
        ``self.errors``, the merge is not interrupted.
        """
        self._categories = self._merge_categories([d.categories() for d in datasets])
        # NOTE(review): merge_media_types is not defined in the visible part
        # of ExactMerge - confirm it is available
        merged = Dataset(
            categories=self._categories, media_type=ExactMerge.merge_media_types(datasets)
        )

        self._check_groups_definition()

        item_matches, self._item_map = self.match_items(datasets)
        self._dataset_map = {id(d): (d, i) for i, d in enumerate(datasets)}

        dataset_count = len(datasets)
        for item_id, items in item_matches.items():
            self._item_id = item_id

            if len(items) < dataset_count:
                absent_ids = set(id(s) for s in datasets) - set(items)
                absent_indices = [self._dataset_map[s][1] for s in absent_ids]
                self.add_item_error(NoMatchingItemError, sources=absent_indices)

            merged.put(self.merge_items(items))

        return merged
[docs]    def get_ann_source(self, ann_id):
        return self._item_map[self._ann_map[ann_id][1]][1] 
[docs]    def merge_items(self, items):
        self._item = next(iter(items.values()))
        self._ann_map = {}
        sources = []
        for item in items.values():
            self._ann_map.update({id(a): (a, id(item)) for a in item.annotations})
            sources.append(item.annotations)
        log.debug(
            "Merging item %s: source annotations %s" % (self._item_id, list(map(len, sources)))
        )
        annotations = self.merge_annotations(sources)
        annotations = [
            a for a in annotations if self.conf.output_conf_thresh <= a.attributes.get("score", 1)
        ]
        return self._item.wrap(annotations=annotations) 
[docs]    def merge_annotations(self, sources):
        self._make_mergers(sources)
        clusters = self._match_annotations(sources)
        joined_clusters = sum(clusters.values(), [])
        group_map = self._find_cluster_groups(joined_clusters)
        annotations = []
        for t, clusters in clusters.items():
            for cluster in clusters:
                self._check_cluster_sources(cluster)
            merged_clusters = self._merge_clusters(t, clusters)
            for merged_ann, cluster in zip(merged_clusters, clusters):
                attributes = self._find_cluster_attrs(cluster, merged_ann)
                attributes = {
                    k: v for k, v in attributes.items() if k not in self.conf.ignored_attributes
                }
                attributes.update(merged_ann.attributes)
                merged_ann.attributes = attributes
                new_group_id = find(enumerate(group_map), lambda e: id(cluster) in e[1][0])
                if new_group_id is None:
                    new_group_id = 0
                else:
                    new_group_id = new_group_id[0] + 1
                merged_ann.group = new_group_id
            if self.conf.close_distance:
                self._check_annotation_distance(t, merged_clusters)
            annotations += merged_clusters
        if self.conf.groups:
            self._check_groups(annotations)
        return annotations 
[docs]    @staticmethod
    def match_items(datasets):
        item_ids = set((item.id, item.subset) for d in datasets for item in d)
        item_map = {}  # id(item) -> (item, id(dataset))
        matches = OrderedDict()
        for (item_id, item_subset) in sorted(item_ids, key=lambda e: e[0]):
            items = {}
            for d in datasets:
                item = d.get(item_id, subset=item_subset)
                if item:
                    items[id(d)] = item
                    item_map[id(item)] = (item, id(d))
            matches[(item_id, item_subset)] = items
        return matches, item_map 
    def _merge_label_categories(self, sources):
        """
        Merges the label categories of the sources.

        When no source differs from the first label categories found, these
        categories are returned as is (can be None). Otherwise, a new
        LabelCategories is built: labels are looked up by (name, parent),
        new labels are added, a known label gets the missing parent and the
        union of the attributes.

        Raises:
            ConflictingCategoriesError: a label has different non-empty
                parents in different sources.
        """
        common = None
        for src_categories in sources:
            src_cat = src_categories.get(AnnotationType.label)
            if common is None:
                common = src_cat
                continue
            if common != src_cat:
                break
        else:
            # no mismatch found - nothing to merge
            return common

        dst_cat = LabelCategories()
        for src_id, src_categories in enumerate(sources):
            src_cat = src_categories.get(AnnotationType.label)
            if src_cat is None:
                continue

            for src_label in src_cat.items:
                dst_label = dst_cat.find(src_label.name, src_label.parent)[1]

                if dst_label is None:
                    dst_cat.add(src_label.name, src_label.parent, src_label.attributes)
                    continue

                if dst_label != src_label:
                    parents_conflict = (
                        src_label.parent
                        and dst_label.parent
                        and src_label.parent != dst_label.parent
                    )
                    if parents_conflict:
                        raise ConflictingCategoriesError(
                            "Can't merge label category %s (from #%s): "
                            "parent label conflict: %s vs. %s"
                            % (src_label.name, src_id, src_label.parent, dst_label.parent),
                            sources=list(range(src_id)),
                        )
                    dst_label.parent = dst_label.parent or src_label.parent
                    dst_label.attributes |= src_label.attributes

        return dst_cat
    def _merge_point_categories(self, sources, label_cat):
        """
        Merges the points categories of the sources.

        Source label ids are remapped to the ids in the merged label
        categories ``label_cat``. Sources without label or points categories
        are skipped.

        Returns:
            The merged PointsCategories, or None if it has no entries.

        Raises:
            ConflictingCategoriesError: a label has different point
                categories in different sources.
        """
        merged = PointsCategories()

        for src_id, src_categories in enumerate(sources):
            src_label_cat = src_categories.get(AnnotationType.label)
            src_point_cat = src_categories.get(AnnotationType.points)
            if src_label_cat is None or src_point_cat is None:
                continue

            for src_label_id, src_cat in src_point_cat.items.items():
                label_entry = src_label_cat.items[src_label_id]
                src_label = label_entry.name
                dst_label_id = label_cat.find(src_label, label_entry.parent)[0]

                dst_cat = merged.items.get(dst_label_id)
                if dst_cat is None:
                    merged.add(dst_label_id, src_cat.labels, src_cat.joints)
                elif dst_cat != src_cat:
                    raise ConflictingCategoriesError(
                        "Can't merge point category for label "
                        "%s (from #%s): %s vs. %s" % (src_label, src_id, src_cat, dst_cat),
                        sources=list(range(src_id)),
                    )

        return merged if len(merged.items) != 0 else None
    def _merge_mask_categories(self, sources, label_cat):
        """
        Merges the mask categories (colormaps) of the sources.

        Source label ids are remapped to the ids in the merged label
        categories ``label_cat``. Sources without label or mask categories
        are skipped.

        Returns:
            The merged MaskCategories, or None if the colormap is empty.

        Raises:
            ConflictingCategoriesError: a label has different colormap
                entries in different sources.
        """
        merged = MaskCategories()

        for src_id, src_categories in enumerate(sources):
            src_label_cat = src_categories.get(AnnotationType.label)
            src_mask_cat = src_categories.get(AnnotationType.mask)
            if src_label_cat is None or src_mask_cat is None:
                continue

            for src_label_id, src_cat in src_mask_cat.colormap.items():
                label_entry = src_label_cat.items[src_label_id]
                src_label = label_entry.name
                dst_label_id = label_cat.find(src_label, label_entry.parent)[0]

                dst_cat = merged.colormap.get(dst_label_id)
                if dst_cat is None:
                    merged.colormap[dst_label_id] = src_cat
                elif dst_cat != src_cat:
                    raise ConflictingCategoriesError(
                        "Can't merge mask category for label "
                        "%s (from #%s): %s vs. %s" % (src_label, src_id, src_cat, dst_cat),
                        sources=list(range(src_id)),
                    )

        return merged if len(merged.colormap) != 0 else None
    def _merge_categories(self, sources):
        """
        Merges the categories of all the sources.

        The result always has label categories (empty ones, if the sources
        have none); points and mask categories are only included when
        their merge results are not None.
        """
        label_cat = self._merge_label_categories(sources)
        if label_cat is None:
            label_cat = LabelCategories()

        dst_categories = {AnnotationType.label: label_cat}

        # The dependent categories are re-indexed by the merged labels
        dependent_mergers = (
            (AnnotationType.points, self._merge_point_categories),
            (AnnotationType.mask, self._merge_mask_categories),
        )
        for cat_type, merge_fn in dependent_mergers:
            merged_cat = merge_fn(sources, label_cat)
            if merged_cat is not None:
                dst_categories[cat_type] = merged_cat

        return dst_categories
    def _match_annotations(self, sources):
        """
        Clusters the source annotations, separately for each annotation type.

        Returns a dict ``annotation type -> list of clusters``, the clusters
        are produced by the merger of the type.
        """
        # annotation type -> list of per-source annotation lists
        per_type_sources = {}
        for source_anns in sources:
            grouped = {}
            for ann in source_anns:
                grouped.setdefault(ann.type, []).append(ann)

            for ann_type, anns in grouped.items():
                per_type_sources.setdefault(ann_type, []).append(anns)

        return {
            ann_type: list(self._match_ann_type(ann_type, type_sources))
            for ann_type, type_sources in per_type_sources.items()
        }
    def _make_mergers(self, sources):
        """
        Creates a merger for every AnnotationType and saves them in
        ``self._mergers``.

        Each merger is created with the settings from ``self.conf`` it has
        fields for, the instance map of the source annotations and this
        object as the context.

        Raises:
            NotImplementedError: there is no merger for some annotation type.
        """

        def _make(merger_cls, **kwargs):
            kwargs.update(attr.asdict(self.conf))
            known_fields = attr.fields_dict(merger_cls)
            accepted = {k: v for k, v in kwargs.items() if k in known_fields}
            return merger_cls(**accepted, context=self)

        def _for_type(t, **kwargs):
            if t is AnnotationType.label:
                return _make(LabelMerger, **kwargs)
            if t is AnnotationType.bbox:
                return _make(BboxMerger, **kwargs)
            if t is AnnotationType.mask:
                return _make(MaskMerger, **kwargs)
            if t is AnnotationType.polygon:
                return _make(PolygonMerger, **kwargs)
            if t is AnnotationType.polyline:
                return _make(LineMerger, **kwargs)
            if t is AnnotationType.points:
                return _make(PointsMerger, **kwargs)
            if t is AnnotationType.caption:
                return _make(CaptionsMerger, **kwargs)
            if t is AnnotationType.cuboid_3d:
                return _make(Cuboid3dMerger, **kwargs)
            if t is AnnotationType.super_resolution_annotation:
                return _make(ImageAnnotationMerger, **kwargs)
            if t is AnnotationType.depth_annotation:
                return _make(ImageAnnotationMerger, **kwargs)
            if t is AnnotationType.skeleton:
                # to do: add skeletons merge
                return _make(ImageAnnotationMerger, **kwargs)
            raise NotImplementedError("Type %s is not supported" % t)

        # The bbox of an instance is computed by these annotation types only
        spatial_types = {AnnotationType.polygon, AnnotationType.mask, AnnotationType.bbox}

        # id(ann) -> [instance annotations, instance bbox]
        instance_map = {}
        for source_anns in sources:
            for instance in find_instances(source_anns):
                instance_bbox = max_bbox([a for a in instance if a.type in spatial_types])

                for ann in instance:
                    instance_map[id(ann)] = [instance, instance_bbox]

        self._mergers = {t: _for_type(t, instance_map=instance_map) for t in AnnotationType}
    def _match_ann_type(self, t, sources):
        """Clusters annotations of the type ``t`` with the merger of this type."""
        merger = self._mergers[t]
        return merger.match_annotations(sources)
    def _merge_clusters(self, t, clusters):
        """Merges clusters of the type ``t`` with the merger of this type."""
        merger = self._mergers[t]
        return merger.merge_clusters(clusters)
    @staticmethod
    def _find_cluster_groups(clusters):
        """
        Joins the clusters connected by the source annotation groups.

        Clusters belong to the same cluster group when their annotations
        share source ``group`` values. The group values of a cluster group
        are collected in one forward pass over the following clusters.

        Returns:
            A list of ``(cluster ids, group values)`` pairs, where cluster
            ids is a set of id() values of the clusters. Cluster groups
            with the only group value 0 (no group) are not included.
        """
        # Source group values of each cluster
        group_sets = [set(ann.group for ann in cluster) for cluster in clusters]
        cluster_count = len(clusters)

        cluster_groups = []
        visited = set()
        for a_idx, cluster_a in enumerate(clusters):
            if a_idx in visited:
                continue
            visited.add(a_idx)

            member_ids = {id(cluster_a)}

            # find segment groups in the cluster group
            # (a copy is used, the cached set must stay intact)
            a_groups = set(group_sets[a_idx])
            for b_groups in group_sets[a_idx + 1 :]:
                if a_groups & b_groups:
                    a_groups |= b_groups

            # now we know all the segment groups in this cluster group
            # so we can find adjacent clusters
            for b_idx in range(a_idx + 1, cluster_count):
                if a_groups & group_sets[b_idx]:
                    member_ids.add(id(clusters[b_idx]))
                    visited.add(b_idx)

            if a_groups != {0}:  # skip annotations without a group
                cluster_groups.append((member_ids, a_groups))

        return cluster_groups
    def _find_cluster_attrs(self, cluster, ann):
        """
        Selects attribute values for a merged annotation by voting.

        Every annotation of the cluster votes for its attribute values. The
        most voted value of an attribute wins if it has at least
        ``conf.quorum`` votes. Otherwise, the attribute is omitted and
        a FailedAttrVotingError is recorded in ``self.errors``.

        Parameters:
            cluster: list of matching source annotations
            ann: the merged annotation, only used in the error reports

        Returns:
            A dict ``attribute name -> winner value``.
        """
        quorum = self.conf.quorum or 0

        # TODO: when attribute types are implemented, add linear
        # interpolation for contiguous values

        attr_votes = {}  # name -> { value: score , ... }
        for src_ann in cluster:
            for name, value in src_ann.attributes.items():
                votes = attr_votes.setdefault(name, {})
                votes[value] = 1 + votes.get(value, 0)

        attributes = {}
        for name, votes in attr_votes.items():
            winner, count = max(votes.items(), key=lambda e: e[1])
            if count < quorum:
                # The check has to be done for each cluster annotation
                # separately to find the sources to blame
                if sum(votes.values()) < quorum:
                    # blame provokers
                    blamed_sources = set(
                        self.get_ann_source(id(a))
                        for a in cluster
                        if a.attributes.get(name) == winner
                    )
                else:
                    # blame outliers
                    blamed_sources = set(
                        self.get_ann_source(id(a))
                        for a in cluster
                        if a.attributes.get(name) != winner
                    )
                missing_sources = [self._dataset_map[s][1] for s in blamed_sources]
                self.add_item_error(
                    FailedAttrVotingError, name, votes, ann, sources=missing_sources
                )
                continue
            attributes[name] = winner

        return attributes
    def _check_cluster_sources(self, cluster):
        """
        Reports the sources that have no annotation in the cluster.

        Nothing is checked if the cluster has as many annotations as there
        are sources. A source is only reported (with NoMatchingAnnError)
        if it has the current item and the item has annotations.
        """
        if len(cluster) == len(self._dataset_map):
            return

        def _has_annotated_item(source_id):
            item = self._dataset_map[source_id][0].get(*self._item_id)
            return bool(item) and len(item.annotations) != 0

        present_sources = set(self.get_ann_source(id(a)) for a in cluster)
        absent_sources = set(self._dataset_map) - present_sources

        missing_sources = [
            self._dataset_map[s][1] for s in absent_sources if _has_annotated_item(s)
        ]
        if missing_sources:
            self.add_item_error(NoMatchingAnnError, cluster[0], sources=missing_sources)
    def _check_annotation_distance(self, t, annotations):
        """
        Reports pairs of merged annotations of the type ``t`` that are too close.

        Each pair is compared once with the ``distance()`` of the type
        merger; values above ``conf.close_distance`` are reported with
        AnnotationsTooCloseError.
        """
        merger = self._mergers[t]
        for a_idx, a_ann in enumerate(annotations):
            later_annotations = annotations[a_idx + 1 :]
            for b_ann in later_annotations:
                d = merger.distance(a_ann, b_ann)
                if d > self.conf.close_distance:
                    self.add_item_error(AnnotationsTooCloseError, a_ann, b_ann, d)
    def _check_groups(self, annotations):
        """
        Validates annotation groups against the groups set in ``conf.groups``.

        A set of labels is wrong for an expected group if it intersects
        with the group, and either has labels outside of the group or lacks
        required (non-optional) labels of the group. The first violated
        expected group is reported with WrongGroupError.

        Grouped annotations of an instance are checked together, annotations
        without a group are checked one by one. Annotations without the
        ``label`` attribute are ignored.
        """
        # (all label names, optional label names) per expected group
        expected_groups = [
            (
                set(entry[0] for entry in raw_group),
                set(entry[0] for entry in raw_group if entry[1]),
            )
            for raw_group in self.conf.groups
        ]

        def _check_group(group_labels, group):
            for expected, optional in expected_groups:
                common = expected & group_labels
                real_miss = expected - common - optional
                extra = group_labels - expected
                if common and (extra or real_miss):
                    self.add_item_error(WrongGroupError, group_labels, expected, group)
                    break

        for group in find_instances(annotations):
            group_labels = set()
            for ann in group:
                if not hasattr(ann, "label"):
                    continue
                label = self._get_label_name(ann.label)

                if not ann.group:
                    _check_group({label}, [ann])
                else:
                    group_labels.add(label)

            if group_labels:
                _check_group(group_labels, group)
    def _get_label_name(self, label_id):
        """Returns the name of a label of the merged categories, None for the None id."""
        if label_id is not None:
            label_cat = self._categories[AnnotationType.label]
            return label_cat.items[label_id].name
        return None
    def _get_label_id(self, label, parent=""):
        """
        Returns the first element of the ``find(label, parent)`` result of
        the merged label categories; None is returned for the None label.
        """
        if label is None:
            return None
        label_cat = self._categories[AnnotationType.label]
        return label_cat.find(label, parent)[0]
    def _get_src_label_name(self, ann, label_id):
        """
        Returns the label name by the categories of the annotation's source.

        The source dataset of ``ann`` is found with the index maps. None is
        returned for the None label id.
        """
        if label_id is None:
            return None

        item_id = self._ann_map[id(ann)][1]
        dataset_id = self._item_map[item_id][1]
        source_dataset = self._dataset_map[dataset_id][0]
        source_labels = source_dataset.categories()[AnnotationType.label]
        return source_labels.items[label_id].name
    def _get_any_label_name(self, ann, label_id):
        """
        Returns the label name for an annotation from any categories.

        The categories of the annotation's source are tried first. If
        the lookup fails with a KeyError, the merged categories are used.
        None is returned for the None label id.
        """
        if label_id is None:
            return None

        try:
            name = self._get_src_label_name(ann, label_id)
        except KeyError:
            name = self._get_label_name(label_id)
        return name
    def _check_groups_definition(self):
        """
        Checks that the labels of ``conf.groups`` exist in the merged categories.

        Raises:
            ValueError: a label is not found in the merged label categories.
        """
        for group in self.conf.groups:
            for label, _ in group:
                label_cat = self._categories[AnnotationType.label]
                _, entry = label_cat.find(label)
                if entry is not None:
                    continue

                available_labels = [i.name for i in self._categories[AnnotationType.label].items]
                raise ValueError(
                    "Datasets do not contain "
                    "label '%s', available labels %s" % (label, available_labels)
                )
@attrs(kw_only=True)
class AnnotationMatcher:
    """
    Base class for annotation matchers.

    Subclasses must implement ``match_annotations``.
    """

    # The merging strategy object, used to look up label names
    _context: Optional[IntersectMerge] = attrib(default=None)

    def match_annotations(self, sources):
        """Groups annotations of the sources; has to be overridden."""
        raise NotImplementedError()
@attrs
class LabelMatcher(AnnotationMatcher):
    """
    Matcher for label annotations: all labels of an item form one cluster.
    """

    def distance(self, a, b):
        """Returns True when both annotations resolve to the same label name."""
        a_label = self._context._get_any_label_name(a, a.label)
        b_label = self._context._get_any_label_name(b, b.label)
        return a_label == b_label

    def match_annotations(self, sources):
        """
        Returns a single cluster holding the annotations of all sources,
        in source order.
        """
        # Flatten in one pass; sum(sources, []) re-copies the accumulated
        # list for every source.
        return [[ann for source in sources for ann in source]]
@attrs(kw_only=True)
class _ShapeMatcher(AnnotationMatcher):
    """
    Base matcher for spatial annotations.

    Annotations from different sources are paired with ``match_segments``
    and the pairs are then joined into clusters. Note that ``distance``
    is a similarity measure here: larger values mean closer shapes.
    """

    # Minimum similarity for two annotations to be paired
    pairwise_dist = attrib(converter=float, default=0.9)
    # Minimum similarity to every annotation already in a cluster;
    # a negative value means "use pairwise_dist"
    cluster_dist = attrib(converter=float, default=-1.0)

    def match_annotations(self, sources):
        """
        Clusters annotations of several sources.

        Args:
            sources: a list of annotation lists, one per source. The lists
                are passed to ``match_segments``, which sorts them in place.

        Returns:
            A list of clusters, each being a list of annotations. Every
            annotation ends up in exactly one cluster; unmatched annotations
            form clusters of size 1.
        """
        distance = self.distance
        label_matcher = self.label_matcher
        pairwise_dist = self.pairwise_dist
        cluster_dist = self.cluster_dist

        if cluster_dist < 0:
            cluster_dist = pairwise_dist

        # id(annotation) -> (annotation, id(source list))
        id_segm = {id(a): (a, id(s)) for s in sources for a in s}

        def _is_close_enough(cluster, extra_id):
            # check if whole cluster IoU will not be broken
            # when this segment is added
            b = id_segm[extra_id][0]
            for a_id in cluster:
                a = id_segm[a_id][0]
                if distance(a, b) < cluster_dist:
                    return False
            return True

        def _has_same_source(cluster, extra_id):
            # a cluster may hold at most one annotation per source
            b = id_segm[extra_id][1]
            for a_id in cluster:
                a = id_segm[a_id][1]
                if a == b:
                    return True
            return False

        # match segments in sources, pairwise
        adjacent = {i: [] for i in id_segm}  # id(sgm) -> [id(adj_sgm1), ...]
        for a_idx, src_a in enumerate(sources):
            for src_b in sources[a_idx + 1 :]:
                matches, _, _, _ = match_segments(
                    src_a,
                    src_b,
                    dist_thresh=pairwise_dist,
                    distance=distance,
                    label_matcher=label_matcher,
                )
                for a, b in matches:
                    adjacent[id(a)].append(id(b))

        # join all segments into matching clusters
        clusters = []
        visited = set()
        for cluster_idx in adjacent:
            if cluster_idx in visited:
                continue

            cluster = set()
            to_visit = {cluster_idx}
            while to_visit:
                c = to_visit.pop()
                cluster.add(c)
                visited.add(c)

                for i in adjacent[c]:
                    if i in visited:
                        continue
                    if 0 < cluster_dist and not _is_close_enough(cluster, i):
                        continue
                    if _has_same_source(cluster, i):
                        continue

                    to_visit.add(i)

            clusters.append([id_segm[i][0] for i in cluster])

        return clusters

    def distance(self, a, b):
        """Returns the similarity of two shapes, computed by ``segment_iou``."""
        return segment_iou(a, b)

    def label_matcher(self, a, b):
        """Returns True when both annotations resolve to the same label name."""
        a_label = self._context._get_any_label_name(a, a.label)
        b_label = self._context._get_any_label_name(b, b.label)
        return a_label == b_label
@attrs
class BboxMatcher(_ShapeMatcher):
    """Matcher for bounding boxes; uses the ``_ShapeMatcher`` behavior unchanged."""
@attrs
class PolygonMatcher(_ShapeMatcher):
    """Matcher for polygons; uses the ``_ShapeMatcher`` behavior unchanged."""
@attrs
class MaskMatcher(_ShapeMatcher):
    """Matcher for masks; uses the ``_ShapeMatcher`` behavior unchanged."""
@attrs(kw_only=True)
class PointsMatcher(_ShapeMatcher):
    """
    Matcher for point sets, comparing them with OKS inside the
    bounding boxes of the instances they belong to.
    """

    # Passed to OKS() as its ``sigma`` argument
    sigma: Optional[list] = attrib(default=None)
    # id(annotation) -> sequence whose element [1] is the bbox
    # of the annotation's instance
    instance_map = attrib(converter=dict)

    def distance(self, a, b):
        """
        Returns the OKS similarity of two point sets, or 0 when the
        bounding boxes of their instances do not intersect.
        """
        a_bbox = self.instance_map[id(a)][1]
        b_bbox = self.instance_map[id(b)][1]
        # early exit for instances that do not overlap at all
        if bbox_iou(a_bbox, b_bbox) <= 0:
            return 0
        bbox = mean_bbox([a_bbox, b_bbox])
        return OKS(a, b, sigma=self.sigma, bbox=bbox)
@attrs
class LineMatcher(_ShapeMatcher):
    """Matcher for polylines, based on the area enclosed between two lines."""

    def distance(self, a, b):
        """
        Returns the similarity of two polylines as
        ``abs(1 - normalized inter-line area)``, or 1 when the common
        bounding box of the lines has zero area.
        """
        # Compute inter-line area by using the Trapezoid formula
        # https://en.wikipedia.org/wiki/Trapezoidal_rule
        # Normalize it by an upper bound of the area and
        # call the complement of this ratio the "distance"

        # The box area is an early-exit filter for degenerate figures
        bbox = max_bbox([a, b])
        box_area = bbox[2] * bbox[3]
        if not box_area:
            return 1

        def _approx(line, segments):
            # resample the line so that both lines have the same point count
            if len(line) // 2 != segments + 1:
                line = approximate_line(line, segments=segments)
            return np.reshape(line, (-1, 2))

        segments = max(len(a.points) // 2, len(b.points) // 2, 5) - 1

        a = _approx(a.points, segments)
        b = _approx(b.points, segments)

        # sums of the distances between corresponding points
        # at both ends of each segment
        dists = np.linalg.norm(a - b, axis=1)
        dists = dists[:-1] + dists[1:]

        a_steps = np.linalg.norm(a[1:] - a[:-1], axis=1)
        b_steps = np.linalg.norm(b[1:] - b[:-1], axis=1)

        # For the common bbox we can't use
        # - the AABB (axis-aligned bbox) of a point set
        # - the exterior of a point set
        # - the convex hull of a point set
        # because these solutions won't be correctly normalized.
        # The lines can have multiple self-intersections, which can give
        # the inter-line area more than internal area of the options above,
        # producing the value of the distance outside of the [0; 1] range.
        #
        # Instead, we can compute the upper boundary for the inter-line
        # area based on the maximum point distance and line length.
        max_area = np.max(dists) * max(np.sum(a_steps), np.sum(b_steps))

        area = np.dot(dists, a_steps + b_steps) * 0.5 * 0.5 / max(max_area, 1.0)

        return abs(1 - area)
@attrs
class CaptionsMatcher(AnnotationMatcher):
    """Placeholder matcher for captions; matching is not implemented."""

    def match_annotations(self, sources):
        """Always raises NotImplementedError."""
        raise NotImplementedError()
@attrs
class Cuboid3dMatcher(_ShapeMatcher):
    """Placeholder matcher for 3D cuboids; the similarity is not implemented."""

    def distance(self, a, b):
        """Always raises NotImplementedError."""
        raise NotImplementedError()
@attrs
class ImageAnnotationMatcher(AnnotationMatcher):
    """Placeholder matcher for image annotations; matching is not implemented."""

    def match_annotations(self, sources):
        """Always raises NotImplementedError."""
        raise NotImplementedError()
@attrs(kw_only=True)
class AnnotationMerger:
    """Base class for annotation mergers."""

    def merge_clusters(self, clusters):
        """
        Produces merged annotations from clusters of matching annotations.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()
@attrs(kw_only=True)
class LabelMerger(AnnotationMerger, LabelMatcher):
    """
    Merges label annotations by voting: a label is kept when at least
    ``quorum`` annotations in the cluster carry it.
    """

    # Minimum number of votes a label needs to be kept
    quorum = attrib(converter=int, default=0)

    def merge_clusters(self, clusters):
        """
        Merges the single label cluster of an item.

        Args:
            clusters: a list with at most one cluster (a list of
                label annotations).

        Returns:
            A list of new Label annotations, one per accepted label name.
            Each has a "score" attribute equal to the vote count divided
            by the number of datasets known to the context. Rejected labels
            are reported to the context as FailedLabelVotingError.
        """
        assert len(clusters) <= 1
        if len(clusters) == 0:
            return []
        cluster = clusters[0]

        votes = {}  # label -> score
        for ann in cluster:
            label = self._context._get_src_label_name(ann, ann.label)
            votes[label] = 1 + votes.get(label, 0)

        merged = []
        for label, count in votes.items():
            if count < self.quorum:
                sources = self._sources_without_label(cluster, label)
                self._context.add_item_error(FailedLabelVotingError, votes, sources=sources)
                continue

            merged.append(
                Label(
                    self._context._get_label_id(label),
                    attributes={"score": count / len(self._context._dataset_map)},
                )
            )

        return merged

    def _sources_without_label(self, cluster, label):
        """
        Returns the ``_dataset_map[...][1]`` entries of the datasets that
        have annotations in ``cluster``, but none named ``label``.
        """
        context = self._context

        # The source of an annotation is resolved the same way the context
        # does it for label names: annotation -> item -> dataset.
        labels_by_source = {}  # dataset id -> set of label names
        for ann in cluster:
            item_id = context._ann_map[id(ann)][1]
            dataset_id = context._item_map[item_id][1]
            labels_by_source.setdefault(dataset_id, set()).add(
                context._get_src_label_name(ann, ann.label)
            )

        return [
            context._dataset_map[dataset_id][1]
            for dataset_id, source_labels in labels_by_source.items()
            if label not in source_labels
        ]
@attrs(kw_only=True)
class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
    """
    Base merger for spatial annotations: each cluster is reduced to one
    representative shape with a voted label and a combined score.
    """

    # Minimum number of votes the winning label needs to be accepted
    quorum = attrib(converter=int, default=0)

    def merge_clusters(self, clusters):
        """Returns a list with one merged annotation per cluster."""
        return list(map(self.merge_cluster, clusters))

    def find_cluster_label(self, cluster):
        """
        Votes for the label of a cluster.

        Each annotation votes for its label name with the weight of its
        "score" attribute (1.0 if absent); the label with the largest
        total weight wins.

        Returns:
            A tuple (label index or None, score). The score is the winner's
            total weight divided by the number of datasets known to the
            context. The label is None when the winner has fewer than
            ``quorum`` votes; a FailedLabelVotingError is reported then.
        """
        votes = {}  # label name -> [total weight, vote count]
        for s in cluster:
            label = self._context._get_src_label_name(s, s.label)
            state = votes.setdefault(label, [0, 0])
            state[0] += s.attributes.get("score", 1.0)
            state[1] += 1

        label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
        if count < self.quorum:
            self._context.add_item_error(FailedLabelVotingError, votes)
            label = None

        score = score / len(self._context._dataset_map)
        label = self._context._get_label_id(label)
        return label, score

    @staticmethod
    def _merge_cluster_shape_mean_box_nearest(cluster):
        """
        Returns the cluster member with the largest ``segment_iou``
        against the mean bounding box of the cluster.
        """
        mbbox = Bbox(*mean_bbox(cluster))
        dist = (segment_iou(mbbox, s) for s in cluster)
        nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1])
        return cluster[nearest_pos]

    def merge_cluster_shape(self, cluster):
        """
        Selects the representative shape of a cluster.

        Returns:
            A tuple (shape, score), where the score is the mean similarity
            (negative values counted as 0) of the shape to all cluster
            members, itself included.
        """
        shape = self._merge_cluster_shape_mean_box_nearest(cluster)
        shape_score = sum(max(0, self.distance(shape, s)) for s in cluster) / len(cluster)
        return shape, shape_score

    def merge_cluster(self, cluster):
        """
        Merges a cluster into a single annotation.

        Note that the result is one of the cluster's own annotations,
        updated in place: it receives the maximum z-order of the cluster,
        the voted label and a "score" attribute.
        """
        label, label_score = self.find_cluster_label(cluster)
        shape, shape_score = self.merge_cluster_shape(cluster)

        shape.z_order = max(cluster, key=lambda a: a.z_order).z_order
        shape.label = label
        shape.attributes["score"] = label_score * shape_score if label is not None else shape_score

        return shape
@attrs
class BboxMerger(_ShapeMerger, BboxMatcher):
    """Merger for bounding boxes; combines ``_ShapeMerger`` with ``BboxMatcher``."""
@attrs
class PolygonMerger(_ShapeMerger, PolygonMatcher):
    """Merger for polygons; combines ``_ShapeMerger`` with ``PolygonMatcher``."""
@attrs
class MaskMerger(_ShapeMerger, MaskMatcher):
    """Merger for masks; combines ``_ShapeMerger`` with ``MaskMatcher``."""
@attrs
class PointsMerger(_ShapeMerger, PointsMatcher):
    """Merger for point sets; combines ``_ShapeMerger`` with ``PointsMatcher``."""
@attrs
class LineMerger(_ShapeMerger, LineMatcher):
    """Merger for polylines; combines ``_ShapeMerger`` with ``LineMatcher``."""
@attrs
class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
    """
    Placeholder merger for captions; neither matching nor merging
    is implemented by its base classes.
    """
@attrs
class Cuboid3dMerger(_ShapeMerger, Cuboid3dMatcher):
    """
    Merger for 3D cuboids. Selecting the representative shape is not
    implemented, so merging a cluster raises NotImplementedError.
    """

    @staticmethod
    def _merge_cluster_shape_mean_box_nearest(cluster):
        """Always raises NotImplementedError."""
        raise NotImplementedError()
        # mbbox = Bbox(*mean_cuboid(cluster))
        # dist = (segment_iou(mbbox, s) for s in cluster)
        # nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1])
        # return cluster[nearest_pos]

    def merge_cluster(self, cluster):
        """
        Merges a cluster into a single annotation.

        Does the same as ``_ShapeMerger.merge_cluster``, except that
        the z-order of the shape is left untouched.
        """
        label, label_score = self.find_cluster_label(cluster)
        shape, shape_score = self.merge_cluster_shape(cluster)

        shape.label = label
        shape.attributes["score"] = label_score * shape_score if label is not None else shape_score

        return shape
@attrs
class ImageAnnotationMerger(AnnotationMerger, ImageAnnotationMatcher):
    """
    Placeholder merger for image annotations; neither matching nor merging
    is implemented by its base classes.
    """
def match_segments(
    a_segms,
    b_segms,
    distance=segment_iou,
    dist_thresh=1.0,
    label_matcher=lambda a, b: a.label == b.label,
):
    """
    Greedily matches annotations of two lists, one-to-one.

    Annotations of ``a_segms`` are processed in the order of decreasing
    "score" attribute (1 if absent). Each one takes the still unmatched
    annotation of ``b_segms`` with the largest ``distance`` value that is
    not below ``dist_thresh``; on equal values, candidates accepted by
    ``label_matcher`` win, then those with a higher score.

    NOTE: both input lists are sorted in place by score.

    Args:
        a_segms: list of annotations of the first source
        b_segms: list of annotations of the second source
        distance: function(a, b) returning the similarity of two
            annotations; larger means closer
        dist_thresh: the minimum similarity of a match
        label_matcher: function(a, b) telling if the labels agree

    Returns:
        A tuple (matches, mispred, a_unmatched, b_unmatched):
        - matches: pairs (a, b) matched with agreeing labels
        - mispred: pairs (a, b) matched with disagreeing labels
        - a_unmatched, b_unmatched: annotations left without a pair
    """
    assert callable(distance), distance
    assert callable(label_matcher), label_matcher

    a_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1))
    b_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1))

    # a_matches: indices of b_segms matched to a bboxes
    # b_matches: indices of a_segms matched to b bboxes
    a_matches = -np.ones(len(a_segms), dtype=int)
    b_matches = -np.ones(len(b_segms), dtype=int)

    distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms])

    # matches: boxes we succeeded to match completely
    # mispred: boxes we succeeded to match, having label mismatch
    matches = []
    mispred = []

    for a_idx, a_segm in enumerate(a_segms):
        if len(b_segms) == 0:
            break
        matched_b = -1
        max_dist = -1
        b_indices = np.argsort(
            [not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable"
        )  # prioritize those with same label, keep score order
        for b_idx in b_indices:
            if 0 <= b_matches[b_idx]:  # assign a_segm with max conf
                continue
            d = distances[a_idx, b_idx]
            if d < dist_thresh or d <= max_dist:
                continue
            max_dist = d
            matched_b = b_idx

        if matched_b < 0:
            continue
        a_matches[a_idx] = matched_b
        b_matches[matched_b] = a_idx

        b_segm = b_segms[matched_b]

        if label_matcher(a_segm, b_segm):
            matches.append((a_segm, b_segm))
        else:
            mispred.append((a_segm, b_segm))

    # *_unmatched: boxes of (*) we failed to match
    a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0]
    b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0]

    return matches, mispred, a_unmatched, b_unmatched
def mean_std(dataset: IDataset):
    """
    Computes the channel-wise mean and standard deviation of the pixel
    values over all items of the dataset.

    Returns:
        The result of ``_MeanStdCounter.get_result()``: a pair (mean, std).
    """
    counter = _MeanStdCounter()

    for item in dataset:
        counter.accumulate(item)

    return counter.get_result()
class _MeanStdCounter:
    """
    Computes unbiased mean and std. dev. for dataset images, channel-wise.
    """

    def __init__(self):
        self._stats = {}  # (id, subset) -> (pixel count, mean vec, std vec)

    def accumulate(self, item: "DatasetItem"):
        """
        Records the pixel count, mean and std. dev. of the item image.

        Items with unknown media size are skipped with a warning.
        Only the first 3 channels of an image are used; values are
        scaled by 1/255 before the computation.
        """
        size = item.media.size
        if size is None:
            log.warning(
                "Item %s: can't detect image size, "
                "the image will be skipped from pixel statistics",
                item.id,
            )
            return
        count = np.prod(size)

        image = item.media.data
        if len(image.shape) == 2:
            image = image[:, :, np.newaxis]
        else:
            image = image[:, :, :3]
        # opencv is much faster than numpy here
        mean, std = cv2.meanStdDev(image.astype(np.double) / 255)

        self._stats[(item.id, item.subset)] = (count, mean, std)

    def get_result(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
        """
        Combines the accumulated per-image statistics.

        Returns:
            A pair (mean, std) of 3-element vectors in the 0-255 scale.
            Zero vectors are returned if no images were accumulated.
        """
        n = len(self._stats)

        if n == 0:
            return [0, 0, 0], [0, 0, 0]

        # A signed 64-bit type keeps "counts - 1" from wrapping around
        counts = np.empty(n, dtype=np.int64)
        stats = np.empty((n, 2, 3), dtype=np.double)

        for i, (count, img_mean, img_std) in enumerate(self._stats.values()):
            counts[i] = count
            stats[i][0] = img_mean.reshape(-1)
            stats[i][1] = img_std.reshape(-1)

        mean = lambda i, s: s[i][0]
        var = lambda i, s: s[i][1]

        # make variance unbiased
        # (a single-pixel image has zero variance, avoid 0/0 for it)
        correction = counts / np.maximum(counts - 1, 1)
        np.multiply(np.square(stats[:, 1]), correction[:, np.newaxis], out=stats[:, 1])

        # Use an online algorithm to:
        # - handle different image sizes
        # - avoid cancellation problem
        _, mean, var = self._compute_stats(stats, counts, mean, var)

        return mean * 255, np.sqrt(var) * 255

    # Implements online parallel computation of sample variance
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    @staticmethod
    def _pairwise_stats(count_a, mean_a, var_a, count_b, mean_b, var_b):
        """
        Computes vector mean and variance of two sample sets joined together.

        Needed to avoid catastrophic cancellation in floating point computations

        Returns:
            A tuple (total count, mean, variance)
        """

        # allow long arithmetics
        count_a = int(count_a)
        count_b = int(count_b)
        total = count_a + count_b

        delta = mean_b - mean_a
        m_a = var_a * (count_a - 1)
        m_b = var_b * (count_b - 1)
        M2 = m_a + m_b + delta**2 * (count_a * count_b / total)

        # The combined mean must be weighted by the sample counts,
        # a plain average is only correct for sets of equal size
        mean = mean_a + delta * (count_b / total)

        return (total, mean, M2 / (total - 1))

    @staticmethod
    def _compute_stats(stats, counts, mean_accessor, variance_accessor):
        """
        Recursively computes total count, mean and variance
        by splitting the input in halves.

        Args:
            stats: (float array of shape N, 2 * d, d = dimensions of values)
            count: (integer array of shape N)
            mean_accessor: (function(idx, stats)) to retrieve element mean
            variance_accessor: (function(idx, stats)) to retrieve element variance

        Returns:
            A tuple (total count, mean, variance)
        """

        m = mean_accessor
        v = variance_accessor
        n = len(stats)
        if n == 1:
            return counts[0], m(0, stats), v(0, stats)
        if n == 2:
            return __class__._pairwise_stats(
                counts[0], m(0, stats), v(0, stats), counts[1], m(1, stats), v(1, stats)
            )

        h = n // 2
        return __class__._pairwise_stats(
            *__class__._compute_stats(stats[:h], counts[:h], m, v),
            *__class__._compute_stats(stats[h:], counts[h:], m, v),
        )
def compute_image_statistics(dataset: IDataset):
    """
    Computes image statistics of a dataset.

    Returns:
        A dict with 2 entries:
        - "dataset": the total, unique and repeated image counts and
          the groups of repeated images (sorted lists of item keys)
        - "subsets": for each subset, the image count and the
          channel-wise "image mean" and "image std" (or "n/a" when
          no pixel statistics are available for the subset)
    """
    stats = {
        "dataset": {
            "images count": 0,
            "unique images count": 0,
            "repeated images count": 0,
            "repeated images": [],  # [[id1, id2], [id3, id4, id5], ...]
        },
        "subsets": {},
    }

    stats_counter = _MeanStdCounter()
    unique_counter = _ItemMatcher()

    # A single pass over the dataset collects the data for all the results
    for item in dataset:
        stats_counter.accumulate(item)
        unique_counter.process_item(item)

    def _extractor_stats(subset_name, extractor):
        # Reuse the per-image statistics collected above, restricted to
        # the subset (an empty subset name selects all the images)
        sub_counter = _MeanStdCounter()
        sub_counter._stats = {
            k: v
            for k, v in stats_counter._stats.items()
            if subset_name and k[1] == subset_name or not subset_name
        }

        available = len(sub_counter._stats) != 0

        subset_stats = {
            "images count": len(extractor),
        }

        if available:
            mean, std = sub_counter.get_result()

            # the channel order is reversed in the output
            subset_stats.update(
                {
                    "image mean": [float(v) for v in mean[::-1]],
                    "image std": [float(v) for v in std[::-1]],
                }
            )
        else:
            subset_stats.update(
                {
                    "image mean": "n/a",
                    "image std": "n/a",
                }
            )
        return subset_stats

    for subset_name in dataset.subsets():
        stats["subsets"][subset_name] = _extractor_stats(
            subset_name, dataset.get_subset(subset_name)
        )

    unique_items = unique_counter.get_result()
    repeated_items = [sorted(g) for g in unique_items.values() if 1 < len(g)]

    stats["dataset"].update(
        {
            "images count": len(dataset),
            "unique images count": len(unique_items),
            "repeated images count": len(repeated_items),
            "repeated images": repeated_items,  # [[id1, id2], [id3, id4, id5], ...]
        }
    )

    return stats
def compute_ann_statistics(dataset: IDataset):
    """
    Computes annotation statistics of a dataset.

    Returns:
        A dict with the image and annotation counts, the list of
        unannotated items, the annotation counts by type and, under
        "annotations":
        - "labels": the count and distribution of labeled annotations
          and the statistics of their attributes
        - "segments": the average area, area histogram and per-label
          area distribution of labeled masks, polygons and bboxes
    """
    labels = dataset.categories().get(AnnotationType.label, LabelCategories())

    def get_label(ann):
        return labels.items[ann.label].name if ann.label is not None else None

    stats = {
        "images count": 0,
        "annotations count": 0,
        "unannotated images count": 0,
        "unannotated images": [],
        "annotations by type": {
            t.name: {
                "count": 0,
            }
            for t in AnnotationType
        },
        "annotations": {},
    }
    by_type = stats["annotations by type"]

    attr_template = {
        "count": 0,
        "values count": 0,
        "values present": set(),
        "distribution": {},  # value -> (count, total%)
    }
    label_stat = {
        "count": 0,
        "distribution": {l.name: [0, 0] for l in labels.items},  # label -> (count, total%)
        "attributes": {},
    }
    stats["annotations"]["labels"] = label_stat
    segm_stat = {
        "avg. area": 0,
        "area distribution": [],  # a histogram with 10 bins
        # (min, min+10%), ..., (min+90%, max) -> (count, total%)
        "pixel distribution": {l.name: [0, 0] for l in labels.items},  # label -> (count, total%)
    }
    stats["annotations"]["segments"] = segm_stat
    segm_areas = []
    pixel_dist = segm_stat["pixel distribution"]

    for item in dataset:
        if len(item.annotations) == 0:
            stats["unannotated images"].append(item.id)
            continue

        for ann in item.annotations:
            by_type[ann.type.name]["count"] += 1

            # Annotations without a label are only counted by type
            if not hasattr(ann, "label") or ann.label is None:
                continue

            if ann.type in {AnnotationType.mask, AnnotationType.polygon, AnnotationType.bbox}:
                area = ann.get_area()
                segm_areas.append(area)
                pixel_dist[get_label(ann)][0] += int(area)

            label_stat["count"] += 1
            label_stat["distribution"][get_label(ann)][0] += 1

            for name, value in ann.attributes.items():
                # these attributes are excluded from the statistics
                if name.lower() in {"occluded", "visibility", "score", "id", "track_id"}:
                    continue
                attrs_stat = label_stat["attributes"].setdefault(name, deepcopy(attr_template))
                attrs_stat["count"] += 1
                attrs_stat["values present"].add(str(value))
                attrs_stat["distribution"].setdefault(str(value), [0, 0])[0] += 1

    stats["images count"] = len(dataset)

    stats["annotations count"] = sum(t["count"] for t in stats["annotations by type"].values())
    stats["unannotated images count"] = len(stats["unannotated images"])

    # Convert the counts into fractions; "or 1" avoids division by zero
    for label_info in label_stat["distribution"].values():
        label_info[1] = label_info[0] / (label_stat["count"] or 1)

    for label_attr in label_stat["attributes"].values():
        label_attr["values count"] = len(label_attr["values present"])
        label_attr["values present"] = sorted(label_attr["values present"])
        for attr_info in label_attr["distribution"].values():
            attr_info[1] = attr_info[0] / (label_attr["count"] or 1)

    # numpy.sum might be faster, but could overflow with large datasets.
    # Python's int can transparently mutate to be of indefinite precision (long)
    total_pixels = sum(int(a) for a in segm_areas)

    segm_stat["avg. area"] = total_pixels / (len(segm_areas) or 1.0)

    for label_info in segm_stat["pixel distribution"].values():
        label_info[1] = label_info[0] / (total_pixels or 1)

    if len(segm_areas) != 0:
        hist, bins = np.histogram(segm_areas)
        segm_stat["area distribution"] = [
            {
                "min": float(bin_min),
                "max": float(bin_max),
                "count": int(c),
                "percent": int(c) / len(segm_areas),
            }
            for c, (bin_min, bin_max) in zip(hist, zip(bins[:-1], bins[1:]))
        ]

    return stats
@attrs
class DistanceComparator:
    """
    Matches the annotations of two dataset items by their similarity.
    """

    # Minimum similarity of matching annotations,
    # passed to match_segments() as ``dist_thresh``
    iou_threshold = attrib(converter=float, default=0.5)

    def match_annotations(self, item_a, item_b):
        """
        Matches annotations of every annotation type.

        Returns:
            A dict: annotation type -> the result of the matching function
            for this type.

        NOTE(review): ``_match_ann_type`` raises NotImplementedError for
        any AnnotationType member it does not list, so this call fails if
        the enum has such members - confirm this is intended.
        """
        return {t: self._match_ann_type(t, item_a, item_b) for t in AnnotationType}

    def _match_ann_type(self, t, *args):
        """Calls the matching function for the annotation type ``t``."""
        # pylint: disable=no-value-for-parameter
        if t == AnnotationType.label:
            return self.match_labels(*args)
        elif t == AnnotationType.bbox:
            return self.match_boxes(*args)
        elif t == AnnotationType.polygon:
            return self.match_polygons(*args)
        elif t == AnnotationType.mask:
            return self.match_masks(*args)
        elif t == AnnotationType.points:
            return self.match_points(*args)
        elif t == AnnotationType.polyline:
            return self.match_lines(*args)
        # pylint: enable=no-value-for-parameter
        else:
            raise NotImplementedError("Unexpected annotation type %s" % t)

    @staticmethod
    def _get_ann_type(t, item):
        """Returns the list of the item annotations of type ``t``."""
        return [a for a in item.annotations if a.type == t]

    def match_labels(self, item_a, item_b):
        """
        Compares the label annotations of two items as sets of label indices.

        Returns:
            A tuple of sets (matches, a_unmatched, b_unmatched)
        """
        a_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_a))
        b_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_b))

        matches = a_labels & b_labels
        a_unmatched = a_labels - b_labels
        b_unmatched = b_labels - a_labels
        return matches, a_unmatched, b_unmatched

    def _match_segments(self, t, item_a, item_b):
        """
        Matches annotations of type ``t`` with the default similarity of
        match_segments() and returns its result.
        """
        a_boxes = self._get_ann_type(t, item_a)
        b_boxes = self._get_ann_type(t, item_b)
        return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold)

    def match_polygons(self, item_a, item_b):
        """Matches polygons, see ``_match_segments``."""
        return self._match_segments(AnnotationType.polygon, item_a, item_b)

    def match_masks(self, item_a, item_b):
        """Matches masks, see ``_match_segments``."""
        return self._match_segments(AnnotationType.mask, item_a, item_b)

    def match_boxes(self, item_a, item_b):
        """Matches bounding boxes, see ``_match_segments``."""
        return self._match_segments(AnnotationType.bbox, item_a, item_b)

    def match_points(self, item_a, item_b):
        """
        Matches point sets with the ``PointsMatcher`` similarity.

        Returns:
            The result of match_segments()
        """
        a_points = self._get_ann_type(AnnotationType.points, item_a)
        b_points = self._get_ann_type(AnnotationType.points, item_b)

        # The matcher needs the bbox of the instance of each annotation
        instance_map = {}
        for s in [item_a.annotations, item_b.annotations]:
            s_instances = find_instances(s)
            for inst in s_instances:
                inst_bbox = max_bbox(inst)
                for ann in inst:
                    instance_map[id(ann)] = [inst, inst_bbox]
        matcher = PointsMatcher(instance_map=instance_map)

        return match_segments(
            a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance
        )

    def match_lines(self, item_a, item_b):
        """
        Matches polylines with the ``LineMatcher`` similarity.

        Returns:
            The result of match_segments()
        """
        a_lines = self._get_ann_type(AnnotationType.polyline, item_a)
        b_lines = self._get_ann_type(AnnotationType.polyline, item_b)

        matcher = LineMatcher()

        return match_segments(
            a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance
        )
def match_items_by_id(a: "IDataset", b: "IDataset"):
    """
    Matches the items of two datasets by their (id, subset) keys.

    Returns:
        A tuple (matches, a_unmatched, b_unmatched):
        - matches: a list of pairs ([key], [key]) for the keys found
          in both datasets
        - a_unmatched, b_unmatched: sets of keys found in one dataset only
    """
    a_items = {(item.id, item.subset) for item in a}
    b_items = {(item.id, item.subset) for item in b}

    # Both sides of a match are lists to have the same shape
    # as the results of matching by image hash
    matches = [([key], [key]) for key in a_items & b_items]
    a_unmatched = a_items - b_items
    b_unmatched = b_items - a_items
    return matches, a_unmatched, b_unmatched
def match_items_by_image_hash(a: IDataset, b: IDataset):
    """
    Matches the items of two datasets by the hashes of their images,
    as computed by find_unique_images().

    Returns:
        A tuple (matches, a_unmatched, b_unmatched):
        - matches: a list of pairs (keys in a, keys in b), one for each
          hash found in both datasets
        - a_unmatched, b_unmatched: sets of the (id, subset) keys of
          the items with hashes found in one dataset only
    """
    a_hash = find_unique_images(a)
    b_hash = find_unique_images(b)

    a_items = set(a_hash)
    b_items = set(b_hash)

    common = a_items & b_items
    a_only = a_items - b_items
    b_only = b_items - a_items

    matches = [(a_hash[h], b_hash[h]) for h in common]
    a_unmatched = set(i for h in a_only for i in a_hash[h])
    b_unmatched = set(i for h in b_only for i in b_hash[h])

    return matches, a_unmatched, b_unmatched
class _ItemMatcher:
    """
    Groups dataset items by a hash, so that the items
    with equal hashes can be found.
    """

    @staticmethod
    def _default_item_hash(item: DatasetItem):
        """
        Hashes an item by its media.

        Returns:
            The MD5 hex digest of the media data, if the data is available;
            otherwise, the hash of the media path, if there is one;
            otherwise, None (a warning is logged).
        """
        media = item.media
        if media and media.has_data:
            # Disable B303:md5, because the hash is not used in a security context
            return hashlib.md5(item.media.data.tobytes()).hexdigest()  # nosec

        if item.media and item.media.path:
            return hash(item.media.path)

        log.warning(
            "Item (%s, %s) has no image " "info, counted as unique", item.id, item.subset
        )
        return None

    def __init__(self, item_hash: Optional[Callable] = None):
        # hash -> [(id, subset), ...]
        self._unique: Dict[str, Set[Tuple[str, str]]] = {}
        self._hash = item_hash or self._default_item_hash

    def process_item(self, item: DatasetItem):
        """Registers the (id, subset) key of the item under its hash."""
        key = self._hash(item)
        if key is None:
            key = str(id(item))  # anything unique

        group = self._unique.setdefault(key, set())
        group.add((item.id, item.subset))

    def get_result(self):
        """Returns the mapping: hash -> set of (id, subset) keys."""
        return self._unique
def find_unique_images(dataset: IDataset, item_hash: Optional[Callable] = None):
    """
    Groups the items of a dataset by the hash of their images.

    Args:
        dataset: the dataset to process
        item_hash: function(item) returning the hash of an item, or None
            for an item to be counted as unique. The default hash of
            ``_ItemMatcher`` is used when omitted.

    Returns:
        A dict: hash -> set of (id, subset) keys of the items
    """
    matcher = _ItemMatcher(item_hash=item_hash)
    for item in dataset:
        matcher.process_item(item)
    return matcher.get_result()
def match_classes(a: CategoriesInfo, b: CategoriesInfo):
    """
    Compares the label names of two sets of categories.

    Missing label categories are treated as empty.

    Returns:
        A tuple of sets of label names (matches, a_unmatched, b_unmatched)
    """
    a_label_cat = a.get(AnnotationType.label, LabelCategories())
    b_label_cat = b.get(AnnotationType.label, LabelCategories())

    a_labels = set(c.name for c in a_label_cat)
    b_labels = set(c.name for c in b_label_cat)

    matches = a_labels & b_labels
    a_unmatched = a_labels - b_labels
    b_unmatched = b_labels - a_labels
    return matches, a_unmatched, b_unmatched
@attrs
class ExactComparator:
    """
    Compares two datasets for exact equality of categories, item attributes
    and annotations. Differences are collected and returned instead of
    being raised.

    All options are keyword-only:
        match_images: if true, items are paired using
            ``match_items_by_image_hash``, otherwise ``match_items_by_id``.
        ignored_fields: names of annotation fields that are blanked out
            (set to ``None``) on both sides before annotations are compared.
        ignored_attrs: annotation attribute names passed to ``filter_dict``
            before comparison (presumably the keys to drop - confirm
            against ``datumaro.util.filter_dict``).
        ignored_item_attrs: same as ``ignored_attrs``, for item attributes.
    """

    match_images: bool = attrib(kw_only=True, default=False)
    ignored_fields = attrib(kw_only=True, factory=set, validator=default_if_none(set))
    ignored_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set))
    ignored_item_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set))

    # Only used for its assertEqual() diff messages
    _test: TestCase = attrib(init=False)
    # Category-level and item-level errors of the last compare_datasets() call
    errors: list = attrib(init=False)

    def __attrs_post_init__(self):
        self._test = TestCase()
        self._test.maxDiff = None  # do not truncate diffs in error messages
        # Make the attribute exist before the first comparison is run
        self.errors = []

    def _match_items(self, a, b):
        """Pairs the items of two datasets according to ``match_images``."""
        if self.match_images:
            return match_items_by_image_hash(a, b)
        else:
            return match_items_by_id(a, b)

    def _compare_categories(self, a, b):
        """
        Compares two category dictionaries and appends the found
        differences to ``self.errors`` as ``{"type", "message"}`` entries.
        """
        test = self._test
        errors = self.errors

        # The sets of declared category types must be the same
        try:
            test.assertEqual(sorted(a, key=lambda t: t.value), sorted(b, key=lambda t: t.value))
        except AssertionError as e:
            errors.append({"type": "categories", "message": str(e)})

        # (category type, compared field, reported error type)
        checks = (
            (AnnotationType.label, "items", "labels"),
            (AnnotationType.mask, "colormap", "colormap"),
            (AnnotationType.points, "items", "points"),
        )
        for ann_type, field, error_type in checks:
            # A type missing on either side is already reported by the
            # "categories" check above; indexing it would raise KeyError.
            if ann_type not in a or ann_type not in b:
                continue

            try:
                test.assertEqual(getattr(a[ann_type], field), getattr(b[ann_type], field))
            except AssertionError as e:
                errors.append({"type": error_type, "message": str(e)})

    def _compare_annotations(self, a, b):
        """
        Checks two annotations for equality, disregarding ignored fields
        and ignored attributes. Returns the result of the ``==`` comparison.
        """
        ignored_fields = self.ignored_fields
        ignored_attrs = self.ignored_attrs

        # Ignored fields are blanked out on both sides
        a_fields = {k: None for k in a.as_dict() if k in ignored_fields}
        b_fields = {k: None for k in b.as_dict() if k in ignored_fields}
        if "attributes" not in ignored_fields:
            a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs)
            b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs)

        return a.wrap(**a_fields) == b.wrap(**b_fields)

    def _compare_items(self, item_a, item_b):
        """
        Compares the attributes and annotations of two items.

        Each annotation of ``item_b`` can be matched with at most one
        annotation of ``item_a``.

        Returns:
            A tuple ``(matched, unmatched, errors)`` of lists of dictionaries:
            matched annotation pairs, annotations without a counterpart
            (from either item, see the "source" key) and item attribute
            mismatches.
        """
        test = self._test

        a_id = (item_a.id, item_a.subset)
        b_id = (item_b.id, item_b.subset)

        matched = []
        unmatched = []
        errors = []

        try:
            test.assertEqual(
                filter_dict(item_a.attributes, self.ignored_item_attrs),
                filter_dict(item_b.attributes, self.ignored_item_attrs),
            )
        except AssertionError as e:
            errors.append({"type": "item_attr", "a_item": a_id, "b_item": b_id, "message": str(e)})

        # Annotations of item_b that are not matched yet. Candidates are taken
        # only from this list, so an annotation is never matched twice.
        remaining_b = list(item_b.annotations)
        for ann_a in item_a.annotations:
            match_idx = next(
                (
                    idx
                    for idx, candidate in enumerate(remaining_b)
                    if candidate.type == ann_a.type
                    and self._compare_annotations(ann_a, candidate)
                ),
                None,
            )

            if match_idx is None:
                unmatched.append({"item": a_id, "source": "a", "ann": str(ann_a)})
                continue

            ann_b = remaining_b.pop(match_idx)  # avoid repeats
            matched.append({"a_item": a_id, "b_item": b_id, "a": str(ann_a), "b": str(ann_b)})

        for ann_b in remaining_b:
            unmatched.append({"item": b_id, "source": "b", "ann": str(ann_b)})

        return matched, unmatched, errors

    def compare_datasets(self, a, b):
        """
        Compares two datasets.

        Categories are compared first. If the label categories differ, the
        items are paired but their annotations are not compared.

        Within every group of paired item ids, items with fully matching
        annotations are paired first; the rest are paired greedily with
        the remaining candidate that has the fewest unmatched annotations
        and errors.

        Returns:
            A tuple ``(matched, unmatched, a_unmatched, b_unmatched, errors)``:
            - matched: list of matched annotation pairs;
            - unmatched: list of annotations without a counterpart;
            - a_unmatched, b_unmatched: sets of ``(id, subset)`` of items
              without a pair in the other dataset;
            - errors: list of category and item attribute mismatches
              (the same list object as ``self.errors``).
        """
        self.errors = []
        errors = self.errors

        self._compare_categories(a.categories(), b.categories())

        matched = []
        unmatched = []

        matches, a_unmatched, b_unmatched = self._match_items(a, b)

        if a.categories().get(AnnotationType.label) != b.categories().get(AnnotationType.label):
            return matched, unmatched, a_unmatched, b_unmatched, errors

        for a_ids, b_ids in matches:
            # a_id -> {b_id: (matched, unmatched, errors)}
            match_status = {}
            a_matches = {a_id: None for a_id in a_ids}
            b_matches = {b_id: None for b_id in b_ids}

            # First pass: compute comparison results and pair up the items
            # whose annotations match completely.
            for a_id in a_ids:
                item_a = a.get(*a_id)
                candidates = {}

                for b_id in b_ids:
                    if b_matches[b_id] is not None:
                        # Already paired, must not be assigned twice
                        continue

                    item_b = b.get(*b_id)
                    i_m, i_um, i_err = self._compare_items(item_a, item_b)
                    candidates[b_id] = (i_m, i_um, i_err)

                    if not i_um:
                        a_matches[a_id] = b_id
                        b_matches[b_id] = a_id
                        matched.extend(i_m)
                        errors.extend(i_err)
                        break

                match_status[a_id] = candidates

            # Second pass: assign the closest free candidate to the rest
            for a_id in a_ids:
                if a_matches[a_id] is not None:
                    continue

                best_b = None
                best_dist = None
                for b_id in b_ids:
                    if b_matches[b_id] is not None:
                        continue

                    # Every b item still free here was compared with a_id
                    # in the first pass, so the lookup cannot fail.
                    _, i_um, i_err = match_status[a_id][b_id]
                    dist = len(i_um) + len(i_err)
                    if best_dist is None or dist < best_dist:
                        best_b = b_id
                        best_dist = dist

                if best_b is None:
                    continue

                a_matches[a_id] = best_b
                b_matches[best_b] = a_id

                i_m, i_um, i_err = match_status[a_id][best_b]
                matched.extend(i_m)
                unmatched.extend(i_um)
                errors.extend(i_err)

            a_unmatched |= {a_id for a_id, m in a_matches.items() if m is None}
            b_unmatched |= {b_id for b_id, m in b_matches.items() if m is None}

        return matched, unmatched, a_unmatched, b_unmatched, errors