# Source code for datumaro.plugins.lfw_format

# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp
import re

from datumaro.components.annotation import AnnotationType, Label, LabelCategories, Points
from datumaro.components.converter import Converter
from datumaro.components.errors import MediaTypeError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.media import Image
from datumaro.util.image import find_images
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class LfwPath:
    """Names of the directories and files that make up an LFW dataset."""

    # <dataset>/<subset>/images/ - image files, optionally grouped by person
    IMAGES_DIR = "images"
    # <dataset>/<subset>/annotations/ - the text files listed below
    ANNOTATION_DIR = "annotations"
    LANDMARKS_FILE = "landmarks.txt"
    PAIRS_FILE = "pairs.txt"
    PEOPLE_FILE = "people.txt"

    # Extension used when exporting images and naming them in landmarks.txt
    IMAGE_EXT = ".jpg"

    # Splits "<person>_<number>" image names,
    # e.g. "Aaron_Peirsol_0001" -> ("Aaron_Peirsol", "0001")
    PATTERN = re.compile(r"([\w-]+)_([-\d]+)")
class LfwExtractor(SourceExtractor):
    """Reads one subset of an LFW dataset, starting from its ``pairs.txt``.

    Layout, as derived from the paths built in ``__init__``::

        <dataset>/<subset>/annotations/pairs.txt
        <dataset>/<subset>/annotations/people.txt      (optional)
        <dataset>/<subset>/annotations/landmarks.txt   (optional)
        <dataset>/<subset>/images/...
    """

    def __init__(self, path, subset=None):
        """
        Args:
            path: Path to the ``pairs.txt`` annotation file.
            subset: Subset name. Defaults to the name of the directory that
                contains the ``annotations`` directory.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            subset = osp.basename(osp.dirname(osp.dirname(path)))
        super().__init__(subset=subset)

        self._dataset_dir = osp.dirname(osp.dirname(osp.dirname(path)))
        self._annotations_dir = osp.dirname(path)
        # Images sit next to this annotations directory. Derive their location
        # from the file path, not from the subset name, which the caller may
        # override with a value that doesn't match the directory on disk.
        self._images_dir = osp.join(osp.dirname(self._annotations_dir), LfwPath.IMAGES_DIR)

        people_file = osp.join(osp.dirname(path), LfwPath.PEOPLE_FILE)
        self._categories = self._load_categories(people_file)
        self._items = list(self._load_items(path).values())

    def _load_categories(self, path):
        """Build the label categories.

        Labels come from the dataset meta file when one exists. Otherwise
        they come from ``path`` (``people.txt``), taking the first field of
        every two-field, tab-separated line. A missing file yields no labels.
        """
        if has_meta_file(self._dataset_dir):
            return {
                AnnotationType.label: LabelCategories.from_iterable(
                    parse_meta_file(self._dataset_dir).keys()
                )
            }

        label_cat = LabelCategories()
        if osp.isfile(path):
            with open(path, encoding="utf-8") as labels_file:
                for line in labels_file:
                    objects = line.strip().split("\t")
                    if len(objects) == 2:
                        label_cat.add(objects[0])
        return {AnnotationType.label: label_cat}

    def _load_items(self, path):
        """Parse ``pairs.txt`` and, if present, ``landmarks.txt``.

        ``pairs.txt`` lines are tab-separated and come in three shapes:

        - ``[<person>/]<image>``: a single, unpaired image;
        - ``<person> <image 1> <image 2>``: a positive pair;
        - ``<person 1> <image 1> <person 2 or -> <image 2>``: a negative pair.

        A pair is recorded only on the first item's label annotation, as the
        ``positive_pairs`` / ``negative_pairs`` attribute (a list of image
        paths of the second items). Unknown person names become new labels.

        Returns:
            dict mapping item id to ``DatasetItem``.
        """
        items = {}
        label_categories = self._categories.get(AnnotationType.label)

        if osp.isdir(self._images_dir):
            # "<relative path without extension, '/'-separated>" -> file path
            images = {
                osp.splitext(osp.relpath(p, self._images_dir))[0].replace("\\", "/"): p
                for p in find_images(self._images_dir, recursive=True)
            }
        else:
            images = {}

        def get_label_id(label_name):
            # Returns None for an empty name; registers unknown names.
            if not label_name:
                return None
            label_id = label_categories.find(label_name)[0]
            if label_id is None:
                label_id = label_categories.add(label_name)
            return label_id

        def make_item(item_id, image_path, annotations):
            # Builds an item whose media is set only when an image was found.
            return DatasetItem(
                id=item_id,
                subset=self._subset,
                media=Image(path=image_path) if image_path else None,
                annotations=annotations,
            )

        def pair_attributes(item, label_name):
            # Returns the attributes of the item's label annotation, where pair
            # lists are kept. An item first listed without a person has no
            # label yet; add one instead of failing with an IndexError.
            for ann in item.annotations:
                if isinstance(ann, Label):
                    return ann.attributes
            label_ann = Label(get_label_id(label_name))
            item.annotations.append(label_ann)
            return label_ann.attributes

        with open(path, encoding="utf-8") as f:
            for line in f:
                pair = line.strip().split("\t")
                if len(pair) == 1 and pair[0] != "":
                    annotations = []
                    image_name = pair[0]
                    item_id = pair[0]
                    objects = item_id.split("/")
                    if 1 < len(objects):
                        label_name = objects[0]
                        label = get_label_id(label_name)
                        if label is not None:
                            annotations.append(Label(label))
                            item_id = item_id[len(label_name) + 1 :]
                    if item_id not in items:
                        # The image is stored under the person's directory, so
                        # look it up by the full "<person>/<name>" path first;
                        # fall back to the bare id for flat layouts.
                        image_path = images.get(image_name) or images.get(item_id)
                        items[item_id] = make_item(item_id, image_path, annotations)

                elif len(pair) == 3:
                    image1, id1 = self.get_image_name(pair[0], pair[1])
                    image2, id2 = self.get_image_name(pair[0], pair[2])
                    label = get_label_id(pair[0])
                    if id1 not in items:
                        items[id1] = make_item(id1, images.get(image1), [Label(label)])
                    if id2 not in items:
                        items[id2] = make_item(id2, images.get(image2), [Label(label)])

                    # pairs form a directed graph
                    attributes = pair_attributes(items[id1], pair[0])
                    if not attributes.get("positive_pairs"):
                        attributes["positive_pairs"] = []
                    attributes["positive_pairs"].append(image2)

                elif len(pair) == 4:
                    image1, id1 = self.get_image_name(pair[0], pair[1])
                    if pair[2] == "-":
                        # The second image belongs to no person
                        image2 = pair[3]
                        id2 = pair[3]
                    else:
                        image2, id2 = self.get_image_name(pair[2], pair[3])
                    if id1 not in items:
                        label = get_label_id(pair[0])
                        items[id1] = make_item(id1, images.get(image1), [Label(label)])
                    if id2 not in items:
                        annotations = []
                        if pair[2] != "-":
                            annotations.append(Label(get_label_id(pair[2])))
                        items[id2] = make_item(id2, images.get(image2), annotations)

                    # pairs form a directed graph
                    attributes = pair_attributes(items[id1], pair[0])
                    if not attributes.get("negative_pairs"):
                        attributes["negative_pairs"] = []
                    attributes["negative_pairs"].append(image2)

        landmarks_file = osp.join(self._annotations_dir, LfwPath.LANDMARKS_FILE)
        if osp.isfile(landmarks_file):
            with open(landmarks_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    fields = line.split("\t")

                    # Reset per line so a label never leaks from a previous line
                    label = None
                    item_id = osp.splitext(fields[0])[0]
                    objects = item_id.split("/")
                    if 1 < len(objects):
                        label_name = objects[0]
                        label = get_label_id(label_name)
                        if label is not None:
                            item_id = item_id[len(label_name) + 1 :]
                    if item_id not in items:
                        items[item_id] = DatasetItem(
                            id=item_id,
                            subset=self._subset,
                            media=Image(path=osp.join(self._images_dir, fields[0])),
                        )
                    items[item_id].annotations.append(
                        Points([float(p) for p in fields[1:]], label=label)
                    )

        return items

    @staticmethod
    def get_image_name(person, image_id):
        """Resolve an image reference from ``pairs.txt``.

        Returns ``(image_path, item_id)``. A numeric ``image_id`` is expanded
        to the LFW naming scheme: ``("Bob", "1")`` becomes
        ``("Bob/Bob_0001", "Bob_0001")``. Any other value is used verbatim as
        the item id, under the person's directory.
        """
        try:
            image_id = int(image_id)
        except ValueError:
            return "{}/{}".format(person, image_id), image_id
        item_id = "{}_{:04d}".format(person, image_id)
        return "{}/{}".format(person, item_id), item_id
class LfwImporter(Importer):
    """Recognizes LFW datasets by their ``annotations/pairs.txt`` file."""

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        # The pairs file inside the annotations directory is mandatory
        required = "/".join([LfwPath.ANNOTATION_DIR, LfwPath.PAIRS_FILE])
        context.require_file(required)

    @classmethod
    def find_sources(cls, path):
        # Look for "pairs" + ".txt" files inside "annotations" directories
        # and report each one as an "lfw" source
        stem, extension = osp.splitext(LfwPath.PAIRS_FILE)
        return cls._find_sources_recursive(
            path,
            extension,
            "lfw",
            filename=stem,
            dirname=LfwPath.ANNOTATION_DIR,
        )
class LfwConverter(Converter):
    """Exports a dataset in the LFW format.

    Each subset produces the following files:

    - ``<subset>/annotations/pairs.txt``: positive pairs, then negative
      pairs, then unpaired items;
    - ``landmarks.txt``: written when any item has points;
    - ``people.txt``: per-label item counts.

    When media saving is enabled, images go to ``<subset>/images/[<label>/]``.
    """

    DEFAULT_IMAGE_EXT = LfwPath.IMAGE_EXT

    def apply(self):
        """Write the dataset to ``self._save_dir``.

        Pairs are taken from the ``positive_pairs`` / ``negative_pairs``
        attributes of each item's first label annotation.

        Raises:
            MediaTypeError: if the dataset media is not images.
        """
        media_type = self._extractor.media_type()
        if media_type and not issubclass(media_type, Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)

        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        # Categories are dataset-wide, so look them up once. A dataset without
        # labels has no people, rather than failing with a KeyError.
        label_categories = self._extractor.categories().get(
            AnnotationType.label, LabelCategories()
        )

        for subset_name, subset in self._extractor.subsets().items():
            # people.txt lists per-subset item counts for every known label
            labels = {label.name: 0 for label in label_categories}
            positive_pairs = []
            negative_pairs = []
            neutral_items = []
            landmarks = []
            # Keys of items already written to pairs.txt; a set keeps the
            # per-item membership checks O(1)
            included_items = set()

            for item in subset:
                anns = [ann for ann in item.annotations if ann.type == AnnotationType.label]
                label, label_name = None, None
                if anns:
                    label = anns[0]
                    label_name = label_categories[label.label].name
                    labels[label_name] += 1

                if self._save_media and item.media:
                    subdir = osp.join(subset_name, LfwPath.IMAGES_DIR)
                    if label_name:
                        subdir = osp.join(subdir, label_name)
                    self._save_image(item, subdir=subdir)

                if label is not None:
                    person1 = label_name
                    num1 = item.id
                    if num1.startswith(person1):
                        # "<person>_<number>" ids are written as the bare
                        # number. Strip only the prefix: str.replace would also
                        # drop later occurrences of the name. Keep the full id
                        # when the rest is not a number; the reader then uses
                        # it verbatim.
                        try:
                            num1 = int(num1[len(person1) + 1 :])
                        except ValueError:
                            pass
                    own_item = person1 + "/" + str(num1)

                    if "positive_pairs" in label.attributes:
                        included_items.add(own_item)
                        for pair in label.attributes["positive_pairs"]:
                            search = LfwPath.PATTERN.search(pair)
                            if search:
                                num2 = int(search.groups()[1])
                            else:
                                num2 = pair
                                if num2.startswith(person1):
                                    num2 = num2[len(person1) + 1 :]
                            included_items.add(person1 + "/" + str(num2))
                            positive_pairs.append("%s\t%s\t%s" % (person1, num1, num2))

                    if "negative_pairs" in label.attributes:
                        included_items.add(own_item)
                        for pair in label.attributes["negative_pairs"]:
                            search = LfwPath.PATTERN.search(pair)
                            partner = ""
                            if search:
                                person2, num2 = search.groups()
                                num2 = int(num2)
                                partner += person2 + "/"
                            else:
                                # "-" marks a second image without a person
                                person2 = "-"
                                num2 = pair
                                objects = pair.split("/")
                                if 1 < len(objects) and objects[0] in labels:
                                    person2 = objects[0]
                                    num2 = pair[len(person2) + 1 :]
                                    partner += person2 + "/"
                            partner += str(num2)
                            included_items.add(partner)
                            negative_pairs.append(
                                "%s\t%s\t%s\t%s" % (person1, num1, person2, num2)
                            )

                    if (
                        "positive_pairs" not in label.attributes
                        and "negative_pairs" not in label.attributes
                        and own_item not in included_items
                    ):
                        neutral_items.append("%s/%s" % (person1, item.id))
                        included_items.add(own_item)
                elif item.id not in included_items:
                    neutral_items.append(item.id)
                    included_items.add(item.id)

                item_landmarks = [p for p in item.annotations if p.type == AnnotationType.points]
                for landmark in item_landmarks:
                    landmarks.append(
                        "%s\t%s"
                        % (item.id + LfwPath.IMAGE_EXT, "\t".join(str(p) for p in landmark.points))
                    )

            annotations_dir = osp.join(self._save_dir, subset_name, LfwPath.ANNOTATION_DIR)
            pairs_file = osp.join(annotations_dir, LfwPath.PAIRS_FILE)
            os.makedirs(osp.dirname(pairs_file), exist_ok=True)
            with open(pairs_file, "w", encoding="utf-8") as f:
                f.writelines(["%s\n" % pair for pair in positive_pairs])
                f.writelines(["%s\n" % pair for pair in negative_pairs])
                f.writelines(["%s\n" % item for item in neutral_items])

            if landmarks:
                landmarks_file = osp.join(annotations_dir, LfwPath.LANDMARKS_FILE)
                with open(landmarks_file, "w", encoding="utf-8") as f:
                    f.writelines(["%s\n" % landmark for landmark in landmarks])

            if labels:
                people_file = osp.join(annotations_dir, LfwPath.PEOPLE_FILE)
                with open(people_file, "w", encoding="utf-8") as f:
                    f.writelines(["%s\t%d\n" % (label, labels[label]) for label in labels])