# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import hashlib
import logging as log
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from unittest import TestCase
import attr
import cv2
import numpy as np
from attr import attrib, attrs
from datumaro.components.annotation import (
Annotation,
AnnotationType,
Bbox,
Label,
LabelCategories,
MaskCategories,
PointsCategories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import Dataset, DatasetItemStorage, IDataset
from datumaro.components.errors import (
AnnotationsTooCloseError,
ConflictingCategoriesError,
DatasetMergeError,
FailedAttrVotingError,
FailedLabelVotingError,
MediaTypeError,
MismatchingAttributesError,
MismatchingImageInfoError,
MismatchingMediaError,
MismatchingMediaPathError,
NoMatchingAnnError,
NoMatchingItemError,
VideoMergeError,
WrongGroupError,
)
from datumaro.components.extractor import CategoriesInfo, DatasetItem
from datumaro.components.media import Image, MediaElement, MultiframeImage, PointCloud, Video
from datumaro.util import filter_dict, find
from datumaro.util.annotation_util import (
OKS,
approximate_line,
bbox_iou,
find_instances,
max_bbox,
mean_bbox,
segment_iou,
)
from datumaro.util.attrs_util import default_if_none, ensure_cls
def match_annotations_equal(a, b):
matches = []
a_unmatched = a[:]
b_unmatched = b[:]
for a_ann in a:
for b_ann in b_unmatched:
if a_ann != b_ann:
continue
matches.append((a_ann, b_ann))
a_unmatched.remove(a_ann)
b_unmatched.remove(b_ann)
break
return matches, a_unmatched, b_unmatched
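# For example, matching a=[x, y] against b=[y, z] (equality-comparable
# annotations) yields matches=[(y, y)], a_unmatched=[x], b_unmatched=[z].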
def merge_annotations_equal(a, b):
matches, a_unmatched, b_unmatched = match_annotations_equal(a, b)
return [ann_a for (ann_a, _) in matches] + a_unmatched + b_unmatched
def merge_categories(sources):
categories = {}
for source_idx, source in enumerate(sources):
for cat_type, source_cat in source.items():
existing_cat = categories.setdefault(cat_type, source_cat)
if existing_cat != source_cat and len(source_cat) != 0:
if len(existing_cat) == 0:
categories[cat_type] = source_cat
else:
raise ConflictingCategoriesError(
"Merging of datasets with different categories is "
"only allowed in 'merge' command.",
sources=list(range(source_idx)),
)
return categories
class MergingStrategy(CliPlugin):
    @classmethod
def merge(cls, sources, **options):
instance = cls(**options)
return instance(sources)
    def __init__(self, **options):
super().__init__(**options)
self.__dict__["_sources"] = None
def __call__(self, sources):
raise NotImplementedError()
class ExactMerge:
"""
Merges several datasets using the "simple" algorithm:
- items are matched by (id, subset) pairs
- matching items share the media info available:
- nothing + nothing = nothing
- nothing + something = something
- something A + something B = conflict
- annotations are matched by value and shared
- in case of conflicts, throws an error
"""
    @classmethod
def merge(cls, *sources: IDataset) -> DatasetItemStorage:
items = DatasetItemStorage()
for source_idx, source in enumerate(sources):
for item in source:
existing_item = items.get(item.id, item.subset)
if existing_item is not None:
try:
item = cls._merge_items(existing_item, item)
except DatasetMergeError as e:
e.sources = set(range(source_idx))
raise e
items.put(item)
return items
@classmethod
def _merge_items(cls, existing_item: DatasetItem, current_item: DatasetItem) -> DatasetItem:
return existing_item.wrap(
media=cls._merge_media(existing_item, current_item),
attributes=cls._merge_attrs(
existing_item.attributes,
current_item.attributes,
item_id=(existing_item.id, existing_item.subset),
),
annotations=cls._merge_anno(existing_item.annotations, current_item.annotations),
)
@staticmethod
def _merge_attrs(a: Dict[str, Any], b: Dict[str, Any], item_id: Tuple[str, str]) -> Dict:
merged = {}
for name in a.keys() | b.keys():
a_val = a.get(name, None)
b_val = b.get(name, None)
if name not in a:
m_val = b_val
elif name not in b:
m_val = a_val
elif a_val != b_val:
raise MismatchingAttributesError(item_id, name, a_val, b_val)
else:
m_val = a_val
merged[name] = m_val
return merged
@classmethod
def _merge_media(
cls, item_a: DatasetItem, item_b: DatasetItem
) -> Union[Image, PointCloud, Video]:
if (not item_a.media or isinstance(item_a.media, Image)) and (
not item_b.media or isinstance(item_b.media, Image)
):
media = cls._merge_images(item_a, item_b)
elif (not item_a.media or isinstance(item_a.media, PointCloud)) and (
not item_b.media or isinstance(item_b.media, PointCloud)
):
media = cls._merge_point_clouds(item_a, item_b)
elif (not item_a.media or isinstance(item_a.media, Video)) and (
not item_b.media or isinstance(item_b.media, Video)
):
media = cls._merge_videos(item_a, item_b)
elif (not item_a.media or isinstance(item_a.media, MultiframeImage)) and (
not item_b.media or isinstance(item_b.media, MultiframeImage)
):
media = cls._merge_multiframe_images(item_a, item_b)
elif (not item_a.media or isinstance(item_a.media, MediaElement)) and (
not item_b.media or isinstance(item_b.media, MediaElement)
):
if isinstance(item_a.media, MediaElement) and isinstance(item_b.media, MediaElement):
if (
item_a.media.path
and item_b.media.path
and item_a.media.path != item_b.media.path
):
raise MismatchingMediaPathError(
(item_a.id, item_a.subset), item_a.media.path, item_b.media.path
)
if item_a.media.path:
media = item_a.media
else:
media = item_b.media
elif isinstance(item_a.media, MediaElement):
media = item_a.media
else:
media = item_b.media
else:
raise MismatchingMediaError((item_a.id, item_a.subset), item_a.media, item_b.media)
return media
@staticmethod
def _merge_images(item_a: DatasetItem, item_b: DatasetItem) -> Image:
media = None
if isinstance(item_a.media, Image) and isinstance(item_b.media, Image):
if (
item_a.media.path
and item_b.media.path
and item_a.media.path != item_b.media.path
and item_a.media.has_data is item_b.media.has_data
):
# We use has_data as a replacement for path existence check
# - If only one image has data, we'll use it. The other
# one is just a path metainfo, which is not significant
# in this case.
# - If both images have data or both don't, we need
# to compare paths.
#
                # Different paths can actually point to the same file,
# but it's not the case we'd like to allow here to be
# a "simple" merging strategy used for extractor joining
raise MismatchingMediaPathError(
(item_a.id, item_a.subset), item_a.media.path, item_b.media.path
)
if (
item_a.media.has_size
and item_b.media.has_size
and item_a.media.size != item_b.media.size
):
raise MismatchingImageInfoError(
(item_a.id, item_a.subset), item_a.media.size, item_b.media.size
)
# Avoid direct comparison here for better performance
# If there are 2 "data-only" images, they won't be compared and
# we just use the first one
if item_a.media.has_data:
media = item_a.media
elif item_b.media.has_data:
media = item_b.media
elif item_a.media.path:
media = item_a.media
elif item_b.media.path:
media = item_b.media
elif item_a.media.has_size:
media = item_a.media
elif item_b.media.has_size:
media = item_b.media
else:
assert False, "Unknown image field combination"
if not media.has_data or not media.has_size:
if item_a.media._size:
media._size = item_a.media._size
elif item_b.media._size:
media._size = item_b.media._size
elif isinstance(item_a.media, Image):
media = item_a.media
else:
media = item_b.media
return media
@staticmethod
def _merge_point_clouds(item_a: DatasetItem, item_b: DatasetItem) -> PointCloud:
media = None
if isinstance(item_a.media, PointCloud) and isinstance(item_b.media, PointCloud):
if item_a.media.path and item_b.media.path and item_a.media.path != item_b.media.path:
raise MismatchingMediaPathError(
(item_a.id, item_a.subset), item_a.media.path, item_b.media.path
)
if item_a.media.path or item_a.media.extra_images:
media = item_a.media
if item_b.media.extra_images:
for image in item_b.media.extra_images:
if image not in media.extra_images:
media.extra_images.append(image)
else:
media = item_b.media
if item_a.media.extra_images:
for image in item_a.media.extra_images:
if image not in media.extra_images:
media.extra_images.append(image)
elif isinstance(item_a.media, PointCloud):
media = item_a.media
else:
media = item_b.media
return media
@staticmethod
def _merge_videos(item_a: DatasetItem, item_b: DatasetItem) -> Video:
media = None
if isinstance(item_a.media, Video) and isinstance(item_b.media, Video):
            if (
                item_a.media.path != item_b.media.path
                or item_a.media._start_frame != item_b.media._start_frame
                or item_a.media._end_frame != item_b.media._end_frame
                or item_a.media._step != item_b.media._step
            ):
raise VideoMergeError(item_a.id)
media = item_a.media
elif isinstance(item_a.media, Video):
media = item_a.media
else:
media = item_b.media
return media
@staticmethod
def _merge_multiframe_images(item_a: DatasetItem, item_b: DatasetItem) -> MultiframeImage:
media = None
if isinstance(item_a.media, MultiframeImage) and isinstance(item_b.media, MultiframeImage):
if item_a.media.path and item_b.media.path and item_a.media.path != item_b.media.path:
raise MismatchingMediaPathError(
(item_a.id, item_a.subset), item_a.media.path, item_b.media.path
)
if item_a.media.path or item_a.media.data:
media = item_a.media
if item_b.media.data:
for image in item_b.media.data:
if image not in media.data:
media.data.append(image)
else:
media = item_b.media
if item_a.media.data:
for image in item_a.media.data:
if image not in media.data:
media.data.append(image)
elif isinstance(item_a.media, MultiframeImage):
media = item_a.media
else:
media = item_b.media
return media
@staticmethod
def _merge_anno(a: Iterable[Annotation], b: Iterable[Annotation]) -> List[Annotation]:
return merge_annotations_equal(a, b)
    @staticmethod
def merge_categories(sources: Iterable[IDataset]) -> CategoriesInfo:
return merge_categories(sources)
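    # Referenced from IntersectMerge.__call__() below (and the reason for the
    # MediaTypeError import). A minimal sketch, assuming each IDataset exposes
    # media_type() and that merged datasets must share exactly the same media type:
    @staticmethod
    def merge_media_types(sources: Iterable[IDataset]) -> Optional[Type[MediaElement]]:
        if sources:
            media_type = sources[0].media_type()
            for s in sources:
                # require an exact type match in both directions
                if not issubclass(s.media_type(), media_type) or not issubclass(
                    media_type, s.media_type()
                ):
                    raise MediaTypeError("Datasets have different media types")
            return media_type
        return None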
@attrs
class IntersectMerge(MergingStrategy):
    @attrs(repr_ns="IntersectMerge", kw_only=True)
class Conf:
pairwise_dist = attrib(converter=float, default=0.5)
sigma = attrib(converter=list, factory=list)
output_conf_thresh = attrib(converter=float, default=0)
quorum = attrib(converter=int, default=0)
ignored_attributes = attrib(converter=set, factory=set)
def _groups_converter(value):
result = []
for group in value:
rg = set()
for label in group:
optional = label.endswith("?")
name = label if not optional else label[:-1]
rg.add((name, optional))
result.append(rg)
return result
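        # e.g. [["head", "torso?"]] -> [{("head", False), ("torso", True)}],
        # where a trailing "?" marks the label as optional in the group
        # (label names here are illustrative)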
groups = attrib(converter=_groups_converter, factory=list)
close_distance = attrib(converter=float, default=0.75)
conf = attrib(converter=ensure_cls(Conf), factory=Conf)
# Error trackers:
errors = attrib(factory=list, init=False)
    def add_item_error(self, error, *args, **kwargs):
self.errors.append(error(self._item_id, *args, **kwargs))
# Indexes:
_dataset_map = attrib(init=False) # id(dataset) -> (dataset, index)
_item_map = attrib(init=False) # id(item) -> (item, id(dataset))
_ann_map = attrib(init=False) # id(ann) -> (ann, id(item))
_item_id = attrib(init=False)
_item = attrib(init=False)
# Misc.
_categories = attrib(init=False) # merged categories
def __call__(self, datasets):
self._categories = self._merge_categories([d.categories() for d in datasets])
merged = Dataset(
categories=self._categories, media_type=ExactMerge.merge_media_types(datasets)
)
self._check_groups_definition()
item_matches, item_map = self.match_items(datasets)
self._item_map = item_map
self._dataset_map = {id(d): (d, i) for i, d in enumerate(datasets)}
for item_id, items in item_matches.items():
self._item_id = item_id
if len(items) < len(datasets):
missing_sources = set(id(s) for s in datasets) - set(items)
missing_sources = [self._dataset_map[s][1] for s in missing_sources]
self.add_item_error(NoMatchingItemError, sources=missing_sources)
merged.put(self.merge_items(items))
return merged
    def get_ann_source(self, ann_id):
        return self._item_map[self._ann_map[ann_id][1]][1]
    def merge_items(self, items):
self._item = next(iter(items.values()))
self._ann_map = {}
sources = []
for item in items.values():
self._ann_map.update({id(a): (a, id(item)) for a in item.annotations})
sources.append(item.annotations)
log.debug(
"Merging item %s: source annotations %s" % (self._item_id, list(map(len, sources)))
)
annotations = self.merge_annotations(sources)
annotations = [
a for a in annotations if self.conf.output_conf_thresh <= a.attributes.get("score", 1)
]
return self._item.wrap(annotations=annotations)
    def merge_annotations(self, sources):
self._make_mergers(sources)
clusters = self._match_annotations(sources)
joined_clusters = sum(clusters.values(), [])
group_map = self._find_cluster_groups(joined_clusters)
annotations = []
        for t, type_clusters in clusters.items():
            for cluster in type_clusters:
                self._check_cluster_sources(cluster)
            merged_clusters = self._merge_clusters(t, type_clusters)
            for merged_ann, cluster in zip(merged_clusters, type_clusters):
attributes = self._find_cluster_attrs(cluster, merged_ann)
attributes = {
k: v for k, v in attributes.items() if k not in self.conf.ignored_attributes
}
attributes.update(merged_ann.attributes)
merged_ann.attributes = attributes
new_group_id = find(enumerate(group_map), lambda e: id(cluster) in e[1][0])
if new_group_id is None:
new_group_id = 0
else:
new_group_id = new_group_id[0] + 1
merged_ann.group = new_group_id
if self.conf.close_distance:
self._check_annotation_distance(t, merged_clusters)
annotations += merged_clusters
if self.conf.groups:
self._check_groups(annotations)
return annotations
    @staticmethod
def match_items(datasets):
item_ids = set((item.id, item.subset) for d in datasets for item in d)
item_map = {} # id(item) -> (item, id(dataset))
matches = OrderedDict()
for (item_id, item_subset) in sorted(item_ids, key=lambda e: e[0]):
items = {}
for d in datasets:
item = d.get(item_id, subset=item_subset)
if item:
items[id(d)] = item
item_map[id(item)] = (item, id(d))
matches[(item_id, item_subset)] = items
return matches, item_map
def _merge_label_categories(self, sources):
same = True
common = None
for src_categories in sources:
src_cat = src_categories.get(AnnotationType.label)
if common is None:
common = src_cat
elif common != src_cat:
same = False
break
if same:
return common
dst_cat = LabelCategories()
for src_id, src_categories in enumerate(sources):
src_cat = src_categories.get(AnnotationType.label)
if src_cat is None:
continue
for src_label in src_cat.items:
dst_label = dst_cat.find(src_label.name, src_label.parent)[1]
if dst_label is not None:
if dst_label != src_label:
if (
src_label.parent
and dst_label.parent
and src_label.parent != dst_label.parent
):
raise ConflictingCategoriesError(
"Can't merge label category %s (from #%s): "
"parent label conflict: %s vs. %s"
% (src_label.name, src_id, src_label.parent, dst_label.parent),
sources=list(range(src_id)),
)
dst_label.parent = dst_label.parent or src_label.parent
dst_label.attributes |= src_label.attributes
else:
dst_cat.add(src_label.name, src_label.parent, src_label.attributes)
return dst_cat
def _merge_point_categories(self, sources, label_cat):
dst_point_cat = PointsCategories()
for src_id, src_categories in enumerate(sources):
src_label_cat = src_categories.get(AnnotationType.label)
src_point_cat = src_categories.get(AnnotationType.points)
if src_label_cat is None or src_point_cat is None:
continue
for src_label_id, src_cat in src_point_cat.items.items():
src_label = src_label_cat.items[src_label_id].name
src_parent_label = src_label_cat.items[src_label_id].parent
dst_label_id = label_cat.find(src_label, src_parent_label)[0]
dst_cat = dst_point_cat.items.get(dst_label_id)
if dst_cat is not None:
if dst_cat != src_cat:
raise ConflictingCategoriesError(
"Can't merge point category for label "
"%s (from #%s): %s vs. %s" % (src_label, src_id, src_cat, dst_cat),
sources=list(range(src_id)),
)
else:
dst_point_cat.add(dst_label_id, src_cat.labels, src_cat.joints)
if len(dst_point_cat.items) == 0:
return None
return dst_point_cat
def _merge_mask_categories(self, sources, label_cat):
dst_mask_cat = MaskCategories()
for src_id, src_categories in enumerate(sources):
src_label_cat = src_categories.get(AnnotationType.label)
src_mask_cat = src_categories.get(AnnotationType.mask)
if src_label_cat is None or src_mask_cat is None:
continue
for src_label_id, src_cat in src_mask_cat.colormap.items():
src_label = src_label_cat.items[src_label_id].name
src_parent_label = src_label_cat.items[src_label_id].parent
dst_label_id = label_cat.find(src_label, src_parent_label)[0]
dst_cat = dst_mask_cat.colormap.get(dst_label_id)
if dst_cat is not None:
if dst_cat != src_cat:
raise ConflictingCategoriesError(
"Can't merge mask category for label "
"%s (from #%s): %s vs. %s" % (src_label, src_id, src_cat, dst_cat),
sources=list(range(src_id)),
)
else:
dst_mask_cat.colormap[dst_label_id] = src_cat
if len(dst_mask_cat.colormap) == 0:
return None
return dst_mask_cat
def _merge_categories(self, sources):
dst_categories = {}
label_cat = self._merge_label_categories(sources)
if label_cat is None:
label_cat = LabelCategories()
dst_categories[AnnotationType.label] = label_cat
points_cat = self._merge_point_categories(sources, label_cat)
if points_cat is not None:
dst_categories[AnnotationType.points] = points_cat
mask_cat = self._merge_mask_categories(sources, label_cat)
if mask_cat is not None:
dst_categories[AnnotationType.mask] = mask_cat
return dst_categories
def _match_annotations(self, sources):
all_by_type = {}
for s in sources:
src_by_type = {}
for a in s:
src_by_type.setdefault(a.type, []).append(a)
for k, v in src_by_type.items():
all_by_type.setdefault(k, []).append(v)
clusters = {}
for k, v in all_by_type.items():
clusters.setdefault(k, []).extend(self._match_ann_type(k, v))
return clusters
def _make_mergers(self, sources):
def _make(c, **kwargs):
kwargs.update(attr.asdict(self.conf))
fields = attr.fields_dict(c)
return c(**{k: v for k, v in kwargs.items() if k in fields}, context=self)
def _for_type(t, **kwargs):
if t is AnnotationType.label:
return _make(LabelMerger, **kwargs)
elif t is AnnotationType.bbox:
return _make(BboxMerger, **kwargs)
elif t is AnnotationType.mask:
return _make(MaskMerger, **kwargs)
elif t is AnnotationType.polygon:
return _make(PolygonMerger, **kwargs)
elif t is AnnotationType.polyline:
return _make(LineMerger, **kwargs)
elif t is AnnotationType.points:
return _make(PointsMerger, **kwargs)
elif t is AnnotationType.caption:
return _make(CaptionsMerger, **kwargs)
elif t is AnnotationType.cuboid_3d:
return _make(Cuboid3dMerger, **kwargs)
elif t is AnnotationType.super_resolution_annotation:
return _make(ImageAnnotationMerger, **kwargs)
elif t is AnnotationType.depth_annotation:
return _make(ImageAnnotationMerger, **kwargs)
elif t is AnnotationType.skeleton:
                # TODO: add proper skeleton merging
return _make(ImageAnnotationMerger, **kwargs)
else:
raise NotImplementedError("Type %s is not supported" % t)
instance_map = {}
for s in sources:
s_instances = find_instances(s)
for inst in s_instances:
inst_bbox = max_bbox(
[
a
for a in inst
if a.type
in {AnnotationType.polygon, AnnotationType.mask, AnnotationType.bbox}
]
)
for ann in inst:
instance_map[id(ann)] = [inst, inst_bbox]
self._mergers = {t: _for_type(t, instance_map=instance_map) for t in AnnotationType}
def _match_ann_type(self, t, sources):
return self._mergers[t].match_annotations(sources)
def _merge_clusters(self, t, clusters):
return self._mergers[t].merge_clusters(clusters)
@staticmethod
def _find_cluster_groups(clusters):
cluster_groups = []
visited = set()
for a_idx, cluster_a in enumerate(clusters):
if a_idx in visited:
continue
visited.add(a_idx)
cluster_group = {id(cluster_a)}
# find segment groups in the cluster group
a_groups = set(ann.group for ann in cluster_a)
for cluster_b in clusters[a_idx + 1 :]:
b_groups = set(ann.group for ann in cluster_b)
if a_groups & b_groups:
a_groups |= b_groups
# now we know all the segment groups in this cluster group
# so we can find adjacent clusters
for b_idx, cluster_b in enumerate(clusters[a_idx + 1 :]):
b_idx = a_idx + 1 + b_idx
b_groups = set(ann.group for ann in cluster_b)
if a_groups & b_groups:
cluster_group.add(id(cluster_b))
visited.add(b_idx)
if a_groups == {0}:
continue # skip annotations without a group
cluster_groups.append((cluster_group, a_groups))
return cluster_groups
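    # For example, clusters c1 (annotation groups {1}), c2 ({1, 2}) and c3 ({3})
    # produce [({id(c1), id(c2)}, {1, 2}), ({id(c3)}, {3})]; clusters whose
    # annotations all have group == 0 are skipped.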
def _find_cluster_attrs(self, cluster, ann):
quorum = self.conf.quorum or 0
# TODO: when attribute types are implemented, add linear
# interpolation for contiguous values
attr_votes = {} # name -> { value: score , ... }
for s in cluster:
for name, value in s.attributes.items():
votes = attr_votes.get(name, {})
votes[value] = 1 + votes.get(value, 0)
attr_votes[name] = votes
attributes = {}
for name, votes in attr_votes.items():
winner, count = max(votes.items(), key=lambda e: e[1])
if count < quorum:
if sum(votes.values()) < quorum:
# blame provokers
missing_sources = set(
self.get_ann_source(id(a))
for a in cluster
                        if a.attributes.get(name) == winner
)
else:
# blame outliers
missing_sources = set(
self.get_ann_source(id(a))
for a in cluster
                        if a.attributes.get(name) != winner
)
missing_sources = [self._dataset_map[s][1] for s in missing_sources]
self.add_item_error(
FailedAttrVotingError, name, votes, ann, sources=missing_sources
)
continue
attributes[name] = winner
return attributes
def _check_cluster_sources(self, cluster):
if len(cluster) == len(self._dataset_map):
return
def _has_item(s):
item = self._dataset_map[s][0].get(*self._item_id)
if not item:
return False
if len(item.annotations) == 0:
return False
return True
missing_sources = set(self._dataset_map) - set(self.get_ann_source(id(a)) for a in cluster)
missing_sources = [self._dataset_map[s][1] for s in missing_sources if _has_item(s)]
if missing_sources:
self.add_item_error(NoMatchingAnnError, cluster[0], sources=missing_sources)
def _check_annotation_distance(self, t, annotations):
for a_idx, a_ann in enumerate(annotations):
for b_ann in annotations[a_idx + 1 :]:
d = self._mergers[t].distance(a_ann, b_ann)
if self.conf.close_distance < d:
self.add_item_error(AnnotationsTooCloseError, a_ann, b_ann, d)
def _check_groups(self, annotations):
check_groups = []
for check_group_raw in self.conf.groups:
check_group = set(l[0] for l in check_group_raw)
optional = set(l[0] for l in check_group_raw if l[1])
check_groups.append((check_group, optional))
def _check_group(group_labels, group):
for check_group, optional in check_groups:
common = check_group & group_labels
real_miss = check_group - common - optional
extra = group_labels - check_group
if common and (extra or real_miss):
self.add_item_error(WrongGroupError, group_labels, check_group, group)
break
groups = find_instances(annotations)
for group in groups:
group_labels = set()
for ann in group:
if not hasattr(ann, "label"):
continue
label = self._get_label_name(ann.label)
if ann.group:
group_labels.add(label)
else:
_check_group({label}, [ann])
if not group_labels:
continue
_check_group(group_labels, group)
def _get_label_name(self, label_id):
if label_id is None:
return None
return self._categories[AnnotationType.label].items[label_id].name
def _get_label_id(self, label, parent=""):
if label is not None:
return self._categories[AnnotationType.label].find(label, parent)[0]
return None
def _get_src_label_name(self, ann, label_id):
if label_id is None:
return None
item_id = self._ann_map[id(ann)][1]
dataset_id = self._item_map[item_id][1]
return (
self._dataset_map[dataset_id][0].categories()[AnnotationType.label].items[label_id].name
)
def _get_any_label_name(self, ann, label_id):
if label_id is None:
return None
try:
return self._get_src_label_name(ann, label_id)
except KeyError:
return self._get_label_name(label_id)
def _check_groups_definition(self):
for group in self.conf.groups:
for label, _ in group:
_, entry = self._categories[AnnotationType.label].find(label)
if entry is None:
raise ValueError(
"Datasets do not contain "
"label '%s', available labels %s"
% (label, [i.name for i in self._categories[AnnotationType.label].items])
)
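# A minimal usage sketch for IntersectMerge (parameter values are illustrative):
#   merger = IntersectMerge(conf=IntersectMerge.Conf(pairwise_dist=0.5, quorum=2))
#   merged_dataset = merger(datasets)  # datasets: a list of IDataset sources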
@attrs(kw_only=True)
class AnnotationMatcher:
    _context: Optional[IntersectMerge] = attrib(default=None)
    def match_annotations(self, sources):
        raise NotImplementedError()
@attrs
class LabelMatcher(AnnotationMatcher):
def distance(self, a, b):
a_label = self._context._get_any_label_name(a, a.label)
b_label = self._context._get_any_label_name(b, b.label)
return a_label == b_label
def match_annotations(self, sources):
return [sum(sources, [])]
@attrs(kw_only=True)
class _ShapeMatcher(AnnotationMatcher):
pairwise_dist = attrib(converter=float, default=0.9)
cluster_dist = attrib(converter=float, default=-1.0)
    def match_annotations(self, sources):
distance = self.distance
label_matcher = self.label_matcher
pairwise_dist = self.pairwise_dist
cluster_dist = self.cluster_dist
if cluster_dist < 0:
cluster_dist = pairwise_dist
id_segm = {id(a): (a, id(s)) for s in sources for a in s}
def _is_close_enough(cluster, extra_id):
# check if whole cluster IoU will not be broken
# when this segment is added
b = id_segm[extra_id][0]
for a_id in cluster:
a = id_segm[a_id][0]
if distance(a, b) < cluster_dist:
return False
return True
def _has_same_source(cluster, extra_id):
b = id_segm[extra_id][1]
for a_id in cluster:
a = id_segm[a_id][1]
if a == b:
return True
return False
# match segments in sources, pairwise
adjacent = {i: [] for i in id_segm} # id(sgm) -> [id(adj_sgm1), ...]
for a_idx, src_a in enumerate(sources):
for src_b in sources[a_idx + 1 :]:
matches, _, _, _ = match_segments(
src_a,
src_b,
dist_thresh=pairwise_dist,
distance=distance,
label_matcher=label_matcher,
)
for a, b in matches:
adjacent[id(a)].append(id(b))
# join all segments into matching clusters
clusters = []
visited = set()
for cluster_idx in adjacent:
if cluster_idx in visited:
continue
cluster = set()
to_visit = {cluster_idx}
while to_visit:
c = to_visit.pop()
cluster.add(c)
visited.add(c)
for i in adjacent[c]:
if i in visited:
continue
if 0 < cluster_dist and not _is_close_enough(cluster, i):
continue
if _has_same_source(cluster, i):
continue
to_visit.add(i)
clusters.append([id_segm[i][0] for i in cluster])
return clusters
    def distance(self, a, b):
        return segment_iou(a, b)
    def label_matcher(self, a, b):
a_label = self._context._get_any_label_name(a, a.label)
b_label = self._context._get_any_label_name(b, b.label)
return a_label == b_label
@attrs
class BboxMatcher(_ShapeMatcher):
    pass
@attrs
class PolygonMatcher(_ShapeMatcher):
    pass
@attrs
class MaskMatcher(_ShapeMatcher):
    pass
@attrs(kw_only=True)
class PointsMatcher(_ShapeMatcher):
sigma: Optional[list] = attrib(default=None)
instance_map = attrib(converter=dict)
    def distance(self, a, b):
a_bbox = self.instance_map[id(a)][1]
b_bbox = self.instance_map[id(b)][1]
if bbox_iou(a_bbox, b_bbox) <= 0:
return 0
bbox = mean_bbox([a_bbox, b_bbox])
return OKS(a, b, sigma=self.sigma, bbox=bbox)
@attrs
class LineMatcher(_ShapeMatcher):
    def distance(self, a, b):
# Compute inter-line area by using the Trapezoid formulae
# https://en.wikipedia.org/wiki/Trapezoidal_rule
# Normalize by common bbox and get the bbox fill ratio
# Call this ratio the "distance"
# The box area is an early-exit filter for non-intersected figures
bbox = max_bbox([a, b])
box_area = bbox[2] * bbox[3]
if not box_area:
return 1
def _approx(line, segments):
if len(line) // 2 != segments + 1:
line = approximate_line(line, segments=segments)
return np.reshape(line, (-1, 2))
segments = max(len(a.points) // 2, len(b.points) // 2, 5) - 1
a = _approx(a.points, segments)
b = _approx(b.points, segments)
dists = np.linalg.norm(a - b, axis=1)
dists = dists[:-1] + dists[1:]
a_steps = np.linalg.norm(a[1:] - a[:-1], axis=1)
b_steps = np.linalg.norm(b[1:] - b[:-1], axis=1)
        # For the common bbox we can't use
        # - the AABB (axis-aligned bbox) of a point set
        # - the exterior of a point set
        # - the convex hull of a point set
        # because these solutions won't be correctly normalized.
        # The lines can have multiple self-intersections, which can give
        # an inter-line area greater than the internal area of the options
        # above, producing a distance value outside of the [0; 1] range.
#
# Instead, we can compute the upper boundary for the inter-line
# area based on the maximum point distance and line length.
max_area = np.max(dists) * max(np.sum(a_steps), np.sum(b_steps))
area = np.dot(dists, a_steps + b_steps) * 0.5 * 0.5 / max(max_area, 1.0)
return abs(1 - area)
@attrs
class CaptionsMatcher(AnnotationMatcher):
    def match_annotations(self, sources):
        raise NotImplementedError()
@attrs
class Cuboid3dMatcher(_ShapeMatcher):
    def distance(self, a, b):
raise NotImplementedError()
@attrs
class ImageAnnotationMatcher(AnnotationMatcher):
def match_annotations(self, sources):
raise NotImplementedError()
@attrs(kw_only=True)
class AnnotationMerger:
    def merge_clusters(self, clusters):
        raise NotImplementedError()
@attrs(kw_only=True)
class LabelMerger(AnnotationMerger, LabelMatcher):
    quorum = attrib(converter=int, default=0)
    def merge_clusters(self, clusters):
assert len(clusters) <= 1
if len(clusters) == 0:
return []
votes = {} # label -> score
for ann in clusters[0]:
label = self._context._get_src_label_name(ann, ann.label)
votes[label] = 1 + votes.get(label, 0)
merged = []
for label, count in votes.items():
if count < self.quorum:
                sources = set(
                    self._context.get_ann_source(id(a))
                    for a in clusters[0]
                    if self._context._get_src_label_name(a, a.label) != label
                )
sources = [self._context._dataset_map[s][1] for s in sources]
self._context.add_item_error(FailedLabelVotingError, votes, sources=sources)
continue
merged.append(
Label(
self._context._get_label_id(label),
attributes={"score": count / len(self._context._dataset_map)},
)
)
return merged
@attrs(kw_only=True)
class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
    quorum = attrib(converter=int, default=0)
    def merge_clusters(self, clusters):
        return list(map(self.merge_cluster, clusters))
    def find_cluster_label(self, cluster):
votes = {}
for s in cluster:
label = self._context._get_src_label_name(s, s.label)
state = votes.setdefault(label, [0, 0])
state[0] += s.attributes.get("score", 1.0)
state[1] += 1
label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
if count < self.quorum:
self._context.add_item_error(FailedLabelVotingError, votes)
label = None
score = score / len(self._context._dataset_map)
label = self._context._get_label_id(label)
return label, score
@staticmethod
def _merge_cluster_shape_mean_box_nearest(cluster):
mbbox = Bbox(*mean_bbox(cluster))
dist = (segment_iou(mbbox, s) for s in cluster)
nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1])
return cluster[nearest_pos]
    def merge_cluster_shape(self, cluster):
        shape = self._merge_cluster_shape_mean_box_nearest(cluster)
        shape_score = sum(max(0, self.distance(shape, s)) for s in cluster) / len(cluster)
        return shape, shape_score
    def merge_cluster(self, cluster):
label, label_score = self.find_cluster_label(cluster)
shape, shape_score = self.merge_cluster_shape(cluster)
shape.z_order = max(cluster, key=lambda a: a.z_order).z_order
shape.label = label
shape.attributes["score"] = label_score * shape_score if label is not None else shape_score
return shape
@attrs
class BboxMerger(_ShapeMerger, BboxMatcher):
    pass
@attrs
class PolygonMerger(_ShapeMerger, PolygonMatcher):
    pass
@attrs
class MaskMerger(_ShapeMerger, MaskMatcher):
    pass
@attrs
class PointsMerger(_ShapeMerger, PointsMatcher):
    pass
@attrs
class LineMerger(_ShapeMerger, LineMatcher):
    pass
@attrs
class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
    pass
@attrs
class Cuboid3dMerger(_ShapeMerger, Cuboid3dMatcher):
@staticmethod
def _merge_cluster_shape_mean_box_nearest(cluster):
raise NotImplementedError()
# mbbox = Bbox(*mean_cuboid(cluster))
# dist = (segment_iou(mbbox, s) for s in cluster)
# nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1])
# return cluster[nearest_pos]
    def merge_cluster(self, cluster):
label, label_score = self.find_cluster_label(cluster)
shape, shape_score = self.merge_cluster_shape(cluster)
shape.label = label
shape.attributes["score"] = label_score * shape_score if label is not None else shape_score
return shape
@attrs
class ImageAnnotationMerger(AnnotationMerger, ImageAnnotationMatcher):
pass
def match_segments(
a_segms,
b_segms,
distance=segment_iou,
dist_thresh=1.0,
label_matcher=lambda a, b: a.label == b.label,
):
assert callable(distance), distance
assert callable(label_matcher), label_matcher
a_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1))
b_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1))
    # a_matches: for each segment in a_segms, the index of the matched
    #   segment in b_segms, or -1
    # b_matches: the same, in the opposite direction
a_matches = -np.ones(len(a_segms), dtype=int)
b_matches = -np.ones(len(b_segms), dtype=int)
distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms])
    # matches: segment pairs matched both spatially and by label
    # mispred: segment pairs matched spatially, but with a label mismatch
matches = []
mispred = []
for a_idx, a_segm in enumerate(a_segms):
if len(b_segms) == 0:
break
matched_b = -1
max_dist = -1
b_indices = np.argsort(
[not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable"
) # prioritize those with same label, keep score order
for b_idx in b_indices:
if 0 <= b_matches[b_idx]: # assign a_segm with max conf
continue
d = distances[a_idx, b_idx]
if d < dist_thresh or d <= max_dist:
continue
max_dist = d
matched_b = b_idx
if matched_b < 0:
continue
a_matches[a_idx] = matched_b
b_matches[matched_b] = a_idx
b_segm = b_segms[matched_b]
if label_matcher(a_segm, b_segm):
matches.append((a_segm, b_segm))
else:
mispred.append((a_segm, b_segm))
    # a_unmatched / b_unmatched: segments we failed to match
a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0]
b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0]
return matches, mispred, a_unmatched, b_unmatched
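# A minimal usage sketch (boxes_a / boxes_b are hypothetical lists of Bbox
# annotations; segment_iou is the default distance):
#   matches, mispred, a_unmatched, b_unmatched = match_segments(
#       boxes_a, boxes_b, dist_thresh=0.5
#   )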
def mean_std(dataset: IDataset):
counter = _MeanStdCounter()
for item in dataset:
counter.accumulate(item)
return counter.get_result()
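# Usage sketch: the returned values are on the 0..255 scale, and the channel
# order follows the loaded image data (presumably BGR here, given the [::-1]
# reversal applied in compute_image_statistics() below):
#   mean, std = mean_std(dataset)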
class _MeanStdCounter:
    """
    Computes unbiased mean and std. dev. for dataset images, channel-wise.
    """
    def __init__(self):
        self._stats = {}  # (id, subset) -> (pixel count, mean vec, std vec)
    def accumulate(self, item: DatasetItem):
size = item.media.size
if size is None:
log.warning(
"Item %s: can't detect image size, "
"the image will be skipped from pixel statistics",
item.id,
)
return
count = np.prod(item.media.size)
image = item.media.data
if len(image.shape) == 2:
image = image[:, :, np.newaxis]
else:
image = image[:, :, :3]
# opencv is much faster than numpy here
mean, std = cv2.meanStdDev(image.astype(np.double) / 255)
self._stats[(item.id, item.subset)] = (count, mean, std)
    def get_result(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
n = len(self._stats)
        if n == 0:
            return (0, 0, 0), (0, 0, 0)
counts = np.empty(n, dtype=np.uint32)
stats = np.empty((n, 2, 3), dtype=np.double)
for i, v in enumerate(self._stats.values()):
counts[i] = v[0]
stats[i][0] = v[1].reshape(-1)
stats[i][1] = v[2].reshape(-1)
mean = lambda i, s: s[i][0]
var = lambda i, s: s[i][1]
# make variance unbiased
np.multiply(np.square(stats[:, 1]), (counts / (counts - 1))[:, np.newaxis], out=stats[:, 1])
# Use an online algorithm to:
# - handle different image sizes
# - avoid cancellation problem
_, mean, var = self._compute_stats(stats, counts, mean, var)
return mean * 255, np.sqrt(var) * 255
# Implements online parallel computation of sample variance
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
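    # Chan et al. update for combining two partitions (n, mean, M2):
    #   delta = mean_b - mean_a
    #   M2 = M2_a + M2_b + delta^2 * n_a * n_b / (n_a + n_b)
    #   mean = mean_a + delta * n_b / (n_a + n_b)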
    @staticmethod
def _pairwise_stats(count_a, mean_a, var_a, count_b, mean_b, var_b):
"""
Computes vector mean and variance.
Needed do avoid catastrophic cancellation in floating point computations
Returns:
A tuple (total count, mean, variance)
"""
# allow long arithmetics
count_a = int(count_a)
count_b = int(count_b)
delta = mean_b - mean_a
m_a = var_a * (count_a - 1)
m_b = var_b * (count_b - 1)
M2 = m_a + m_b + delta**2 * (count_a * count_b / (count_a + count_b))
        # Use the count-weighted mean from the parallel algorithm; a plain
        # (mean_a + mean_b) / 2 would be wrong for unequal pixel counts
        return (
            count_a + count_b,
            mean_a + delta * (count_b / (count_a + count_b)),
            M2 / (count_a + count_b - 1),
        )
    @staticmethod
def _compute_stats(stats, counts, mean_accessor, variance_accessor):
"""
Recursively computes total count, mean and variance,
does O(log(N)) calls.
Args:
stats: (float array of shape N, 2 * d, d = dimensions of values)
count: (integer array of shape N)
mean_accessor: (function(idx, stats)) to retrieve element mean
variance_accessor: (function(idx, stats)) to retrieve element variance
Returns:
A tuple (total count, mean, variance)
"""
m = mean_accessor
v = variance_accessor
n = len(stats)
if n == 1:
return counts[0], m(0, stats), v(0, stats)
if n == 2:
return __class__._pairwise_stats(
counts[0], m(0, stats), v(0, stats), counts[1], m(1, stats), v(1, stats)
)
h = n // 2
return __class__._pairwise_stats(
*__class__._compute_stats(stats[:h], counts[:h], m, v),
*__class__._compute_stats(stats[h:], counts[h:], m, v),
)
def compute_image_statistics(dataset: IDataset):
stats = {
"dataset": {
"images count": 0,
"unique images count": 0,
"repeated images count": 0,
"repeated images": [], # [[id1, id2], [id3, id4, id5], ...]
},
"subsets": {},
}
stats_counter = _MeanStdCounter()
unique_counter = _ItemMatcher()
for item in dataset:
stats_counter.accumulate(item)
unique_counter.process_item(item)
def _extractor_stats(subset_name, extractor):
sub_counter = _MeanStdCounter()
        sub_counter._stats = {
            k: v
            for k, v in stats_counter._stats.items()
            if not subset_name or k[1] == subset_name
        }
available = len(sub_counter._stats) != 0
stats = {
"images count": len(extractor),
}
if available:
mean, std = sub_counter.get_result()
stats.update(
{
"image mean": [float(v) for v in mean[::-1]],
"image std": [float(v) for v in std[::-1]],
}
)
else:
stats.update(
{
"image mean": "n/a",
"image std": "n/a",
}
)
return stats
for subset_name in dataset.subsets():
stats["subsets"][subset_name] = _extractor_stats(
subset_name, dataset.get_subset(subset_name)
)
unique_items = unique_counter.get_result()
repeated_items = [sorted(g) for g in unique_items.values() if 1 < len(g)]
stats["dataset"].update(
{
"images count": len(dataset),
"unique images count": len(unique_items),
"repeated images count": len(repeated_items),
"repeated images": repeated_items, # [[id1, id2], [id3, id4, id5], ...]
}
)
return stats
def compute_ann_statistics(dataset: IDataset):
labels = dataset.categories().get(AnnotationType.label, LabelCategories())
def get_label(ann):
return labels.items[ann.label].name if ann.label is not None else None
stats = {
"images count": 0,
"annotations count": 0,
"unannotated images count": 0,
"unannotated images": [],
"annotations by type": {
t.name: {
"count": 0,
}
for t in AnnotationType
},
"annotations": {},
}
by_type = stats["annotations by type"]
attr_template = {
"count": 0,
"values count": 0,
"values present": set(),
"distribution": {}, # value -> (count, total%)
}
label_stat = {
"count": 0,
"distribution": {l.name: [0, 0] for l in labels.items}, # label -> (count, total%)
"attributes": {},
}
stats["annotations"]["labels"] = label_stat
segm_stat = {
"avg. area": 0,
"area distribution": [], # a histogram with 10 bins
# (min, min+10%), ..., (min+90%, max) -> (count, total%)
"pixel distribution": {l.name: [0, 0] for l in labels.items}, # label -> (count, total%)
}
stats["annotations"]["segments"] = segm_stat
segm_areas = []
pixel_dist = segm_stat["pixel distribution"]
total_pixels = 0
for item in dataset:
if len(item.annotations) == 0:
stats["unannotated images"].append(item.id)
continue
for ann in item.annotations:
by_type[ann.type.name]["count"] += 1
if not hasattr(ann, "label") or ann.label is None:
continue
if ann.type in {AnnotationType.mask, AnnotationType.polygon, AnnotationType.bbox}:
area = ann.get_area()
segm_areas.append(area)
pixel_dist[get_label(ann)][0] += int(area)
label_stat["count"] += 1
label_stat["distribution"][get_label(ann)][0] += 1
for name, value in ann.attributes.items():
if name.lower() in {"occluded", "visibility", "score", "id", "track_id"}:
continue
attrs_stat = label_stat["attributes"].setdefault(name, deepcopy(attr_template))
attrs_stat["count"] += 1
attrs_stat["values present"].add(str(value))
attrs_stat["distribution"].setdefault(str(value), [0, 0])[0] += 1
stats["images count"] = len(dataset)
stats["annotations count"] = sum(t["count"] for t in stats["annotations by type"].values())
stats["unannotated images count"] = len(stats["unannotated images"])
for label_info in label_stat["distribution"].values():
label_info[1] = label_info[0] / (label_stat["count"] or 1)
for label_attr in label_stat["attributes"].values():
label_attr["values count"] = len(label_attr["values present"])
label_attr["values present"] = sorted(label_attr["values present"])
for attr_info in label_attr["distribution"].values():
attr_info[1] = attr_info[0] / (label_attr["count"] or 1)
    # numpy.sum might be faster, but could overflow with large datasets.
    # Python ints have arbitrary precision, so the plain sum cannot overflow.
total_pixels = sum(int(a) for a in segm_areas)
segm_stat["avg. area"] = total_pixels / (len(segm_areas) or 1.0)
for label_info in segm_stat["pixel distribution"].values():
label_info[1] = label_info[0] / (total_pixels or 1)
if len(segm_areas) != 0:
hist, bins = np.histogram(segm_areas)
segm_stat["area distribution"] = [
{
"min": float(bin_min),
"max": float(bin_max),
"count": int(c),
"percent": int(c) / len(segm_areas),
}
for c, (bin_min, bin_max) in zip(hist, zip(bins[:-1], bins[1:]))
]
return stats
@attrs
class DistanceComparator:
    iou_threshold = attrib(converter=float, default=0.5)
    def match_annotations(self, item_a, item_b):
return {t: self._match_ann_type(t, item_a, item_b) for t in AnnotationType}
def _match_ann_type(self, t, *args):
# pylint: disable=no-value-for-parameter
if t == AnnotationType.label:
return self.match_labels(*args)
elif t == AnnotationType.bbox:
return self.match_boxes(*args)
elif t == AnnotationType.polygon:
return self.match_polygons(*args)
elif t == AnnotationType.mask:
return self.match_masks(*args)
elif t == AnnotationType.points:
return self.match_points(*args)
elif t == AnnotationType.polyline:
return self.match_lines(*args)
# pylint: enable=no-value-for-parameter
else:
raise NotImplementedError("Unexpected annotation type %s" % t)
@staticmethod
def _get_ann_type(t, item):
return [a for a in item.annotations if a.type == t]
    def match_labels(self, item_a, item_b):
a_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_a))
b_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_b))
matches = a_labels & b_labels
a_unmatched = a_labels - b_labels
b_unmatched = b_labels - a_labels
return matches, a_unmatched, b_unmatched
def _match_segments(self, t, item_a, item_b):
a_boxes = self._get_ann_type(t, item_a)
b_boxes = self._get_ann_type(t, item_b)
return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold)
    def match_polygons(self, item_a, item_b):
        return self._match_segments(AnnotationType.polygon, item_a, item_b)
    def match_masks(self, item_a, item_b):
        return self._match_segments(AnnotationType.mask, item_a, item_b)
    def match_boxes(self, item_a, item_b):
        return self._match_segments(AnnotationType.bbox, item_a, item_b)
    def match_points(self, item_a, item_b):
a_points = self._get_ann_type(AnnotationType.points, item_a)
b_points = self._get_ann_type(AnnotationType.points, item_b)
instance_map = {}
for s in [item_a.annotations, item_b.annotations]:
s_instances = find_instances(s)
for inst in s_instances:
inst_bbox = max_bbox(inst)
for ann in inst:
instance_map[id(ann)] = [inst, inst_bbox]
matcher = PointsMatcher(instance_map=instance_map)
return match_segments(
a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance
)
    def match_lines(self, item_a, item_b):
a_lines = self._get_ann_type(AnnotationType.polyline, item_a)
b_lines = self._get_ann_type(AnnotationType.polyline, item_b)
matcher = LineMatcher()
return match_segments(
a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance
)
def match_items_by_id(a: IDataset, b: IDataset):
a_items = set((item.id, item.subset) for item in a)
b_items = set((item.id, item.subset) for item in b)
matches = a_items & b_items
matches = [([m], [m]) for m in matches]
a_unmatched = a_items - b_items
b_unmatched = b_items - a_items
return matches, a_unmatched, b_unmatched
def match_items_by_image_hash(a: IDataset, b: IDataset):
a_hash = find_unique_images(a)
b_hash = find_unique_images(b)
a_items = set(a_hash)
b_items = set(b_hash)
matches = a_items & b_items
a_unmatched = a_items - b_items
b_unmatched = b_items - a_items
matches = [(a_hash[h], b_hash[h]) for h in matches]
a_unmatched = set(i for h in a_unmatched for i in a_hash[h])
b_unmatched = set(i for h in b_unmatched for i in b_hash[h])
return matches, a_unmatched, b_unmatched
class _ItemMatcher:
@staticmethod
def _default_item_hash(item: DatasetItem):
if not item.media or not item.media.has_data:
if item.media and item.media.path:
return hash(item.media.path)
            log.warning(
                "Item (%s, %s) has no image info, counted as unique", item.id, item.subset
            )
return None
# Disable B303:md5, because the hash is not used in a security context
return hashlib.md5(item.media.data.tobytes()).hexdigest() # nosec
def __init__(self, item_hash: Optional[Callable] = None):
self._hash = item_hash or self._default_item_hash
# hash -> [(id, subset), ...]
self._unique: Dict[str, Set[Tuple[str, str]]] = {}
def process_item(self, item: DatasetItem):
h = self._hash(item)
if h is None:
h = str(id(item)) # anything unique
self._unique.setdefault(h, set()).add((item.id, item.subset))
def get_result(self):
return self._unique
def find_unique_images(dataset: IDataset, item_hash: Optional[Callable] = None):
matcher = _ItemMatcher(item_hash=item_hash)
for item in dataset:
matcher.process_item(item)
return matcher.get_result()
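# Usage sketch: group dataset items by image content hash
#   groups = find_unique_images(dataset)  # hash -> {(id, subset), ...}
#   duplicates = [sorted(g) for g in groups.values() if 1 < len(g)]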
def match_classes(a: CategoriesInfo, b: CategoriesInfo):
a_label_cat = a.get(AnnotationType.label, LabelCategories())
b_label_cat = b.get(AnnotationType.label, LabelCategories())
a_labels = set(c.name for c in a_label_cat)
b_labels = set(c.name for c in b_label_cat)
matches = a_labels & b_labels
a_unmatched = a_labels - b_labels
b_unmatched = b_labels - a_labels
return matches, a_unmatched, b_unmatched
@attrs
class ExactComparator:
match_images: bool = attrib(kw_only=True, default=False)
ignored_fields = attrib(kw_only=True, factory=set, validator=default_if_none(set))
ignored_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set))
ignored_item_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set))
_test: TestCase = attrib(init=False)
errors: list = attrib(init=False)
def __attrs_post_init__(self):
self._test = TestCase()
self._test.maxDiff = None
def _match_items(self, a, b):
if self.match_images:
return match_items_by_image_hash(a, b)
else:
return match_items_by_id(a, b)
def _compare_categories(self, a, b):
test = self._test
errors = self.errors
try:
test.assertEqual(sorted(a, key=lambda t: t.value), sorted(b, key=lambda t: t.value))
except AssertionError as e:
errors.append({"type": "categories", "message": str(e)})
if AnnotationType.label in a:
try:
test.assertEqual(
a[AnnotationType.label].items,
b[AnnotationType.label].items,
)
except AssertionError as e:
errors.append({"type": "labels", "message": str(e)})
if AnnotationType.mask in a:
try:
test.assertEqual(
a[AnnotationType.mask].colormap,
b[AnnotationType.mask].colormap,
)
except AssertionError as e:
errors.append({"type": "colormap", "message": str(e)})
if AnnotationType.points in a:
try:
test.assertEqual(
a[AnnotationType.points].items,
b[AnnotationType.points].items,
)
except AssertionError as e:
errors.append({"type": "points", "message": str(e)})
def _compare_annotations(self, a, b):
ignored_fields = self.ignored_fields
ignored_attrs = self.ignored_attrs
a_fields = {k: None for k in a.as_dict() if k in ignored_fields}
b_fields = {k: None for k in b.as_dict() if k in ignored_fields}
if "attributes" not in ignored_fields:
a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs)
b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs)
result = a.wrap(**a_fields) == b.wrap(**b_fields)
return result
def _compare_items(self, item_a, item_b):
test = self._test
a_id = (item_a.id, item_a.subset)
b_id = (item_b.id, item_b.subset)
matched = []
unmatched = []
errors = []
try:
test.assertEqual(
filter_dict(item_a.attributes, self.ignored_item_attrs),
filter_dict(item_b.attributes, self.ignored_item_attrs),
)
except AssertionError as e:
errors.append({"type": "item_attr", "a_item": a_id, "b_item": b_id, "message": str(e)})
b_annotations = item_b.annotations[:]
for ann_a in item_a.annotations:
ann_b_candidates = [x for x in item_b.annotations if x.type == ann_a.type]
ann_b = find(
enumerate(self._compare_annotations(ann_a, x) for x in ann_b_candidates),
lambda x: x[1],
)
if ann_b is None:
unmatched.append(
{
"item": a_id,
"source": "a",
"ann": str(ann_a),
}
)
continue
else:
ann_b = ann_b_candidates[ann_b[0]]
b_annotations.remove(ann_b) # avoid repeats
matched.append({"a_item": a_id, "b_item": b_id, "a": str(ann_a), "b": str(ann_b)})
for ann_b in b_annotations:
unmatched.append({"item": b_id, "source": "b", "ann": str(ann_b)})
return matched, unmatched, errors
    def compare_datasets(self, a, b):
self.errors = []
errors = self.errors
self._compare_categories(a.categories(), b.categories())
matched = []
unmatched = []
matches, a_unmatched, b_unmatched = self._match_items(a, b)
if a.categories().get(AnnotationType.label) != b.categories().get(AnnotationType.label):
return matched, unmatched, a_unmatched, b_unmatched, errors
_dist = lambda s: len(s[1]) + len(s[2])
for a_ids, b_ids in matches:
# build distance matrix
match_status = {} # (a_id, b_id): [matched, unmatched, errors]
a_matches = {a_id: None for a_id in a_ids}
b_matches = {b_id: None for b_id in b_ids}
for a_id in a_ids:
item_a = a.get(*a_id)
candidates = {}
for b_id in b_ids:
item_b = b.get(*b_id)
i_m, i_um, i_err = self._compare_items(item_a, item_b)
candidates[b_id] = [i_m, i_um, i_err]
if len(i_um) == 0:
a_matches[a_id] = b_id
b_matches[b_id] = a_id
matched.extend(i_m)
errors.extend(i_err)
break
match_status[a_id] = candidates
# assign
for a_id in a_ids:
if len(b_ids) == 0:
break
# find the closest, ignore already assigned
matched_b = a_matches[a_id]
if matched_b is not None:
continue
min_dist = -1
for b_id in b_ids:
if b_matches[b_id] is not None:
continue
d = _dist(match_status[a_id][b_id])
if d < min_dist and 0 <= min_dist:
continue
min_dist = d
matched_b = b_id
if matched_b is None:
continue
a_matches[a_id] = matched_b
b_matches[matched_b] = a_id
m = match_status[a_id][matched_b]
matched.extend(m[0])
unmatched.extend(m[1])
errors.extend(m[2])
a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m)
b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m)
return matched, unmatched, a_unmatched, b_unmatched, errors
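# A minimal usage sketch for ExactComparator (dataset_a / dataset_b are
# hypothetical IDataset instances):
#   comparator = ExactComparator(match_images=False)
#   matched, unmatched, a_extra, b_extra, errors = comparator.compare_datasets(
#       dataset_a, dataset_b
#   )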