Source code for datumaro.plugins.lfw_format
# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
import os.path as osp
import re
from datumaro.components.annotation import AnnotationType, Label, LabelCategories, Points
from datumaro.components.converter import Converter
from datumaro.components.errors import MediaTypeError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.media import Image
from datumaro.util.image import find_images
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file
[docs]class LfwPath:
IMAGES_DIR = "images"
ANNOTATION_DIR = "annotations"
LANDMARKS_FILE = "landmarks.txt"
PAIRS_FILE = "pairs.txt"
PEOPLE_FILE = "people.txt"
IMAGE_EXT = ".jpg"
PATTERN = re.compile(r"([\w-]+)_([-\d]+)")
[docs]class LfwExtractor(SourceExtractor):
[docs] def __init__(self, path, subset=None):
if not osp.isfile(path):
raise FileNotFoundError("Can't read annotation file '%s'" % path)
if not subset:
subset = osp.basename(osp.dirname(osp.dirname(path)))
super().__init__(subset=subset)
self._dataset_dir = osp.dirname(osp.dirname(osp.dirname(path)))
self._annotations_dir = osp.dirname(path)
self._images_dir = osp.join(self._dataset_dir, self._subset, LfwPath.IMAGES_DIR)
people_file = osp.join(osp.dirname(path), LfwPath.PEOPLE_FILE)
self._categories = self._load_categories(people_file)
self._items = list(self._load_items(path).values())
def _load_categories(self, path):
if has_meta_file(self._dataset_dir):
return {
AnnotationType.label: LabelCategories.from_iterable(
parse_meta_file(self._dataset_dir).keys()
)
}
label_cat = LabelCategories()
if osp.isfile(path):
with open(path, encoding="utf-8") as labels_file:
for line in labels_file:
objects = line.strip().split("\t")
if len(objects) == 2:
label_cat.add(objects[0])
return {AnnotationType.label: label_cat}
def _load_items(self, path):
items = {}
label_categories = self._categories.get(AnnotationType.label)
if osp.isdir(self._images_dir):
images = {
osp.splitext(osp.relpath(p, self._images_dir))[0].replace("\\", "/"): p
for p in find_images(self._images_dir, recursive=True)
}
else:
images = {}
with open(path, encoding="utf-8") as f:
def get_label_id(label_name):
if not label_name:
return None
label_id = label_categories.find(label_name)[0]
if label_id is None:
label_id = label_categories.add(label_name)
return label_id
for line in f:
pair = line.strip().split("\t")
if len(pair) == 1 and pair[0] != "":
annotations = []
image = pair[0]
item_id = pair[0]
objects = item_id.split("/")
if 1 < len(objects):
label_name = objects[0]
label = get_label_id(label_name)
if label is not None:
annotations.append(Label(label))
item_id = item_id[len(label_name) + 1 :]
if item_id not in items:
image = images.get(item_id)
if image:
image = Image(path=image)
items[item_id] = DatasetItem(
id=item_id, subset=self._subset, media=image, annotations=annotations
)
elif len(pair) == 3:
image1, id1 = self.get_image_name(pair[0], pair[1])
image2, id2 = self.get_image_name(pair[0], pair[2])
label = get_label_id(pair[0])
if id1 not in items:
annotations = []
annotations.append(Label(label))
image = images.get(image1)
if image:
image = Image(path=image)
items[id1] = DatasetItem(
id=id1, subset=self._subset, media=image, annotations=annotations
)
if id2 not in items:
annotations = []
annotations.append(Label(label))
image = images.get(image2)
if image:
image = Image(path=image)
items[id2] = DatasetItem(
id=id2, subset=self._subset, media=image, annotations=annotations
)
# pairs form a directed graph
if not items[id1].annotations[0].attributes.get("positive_pairs"):
items[id1].annotations[0].attributes["positive_pairs"] = []
items[id1].annotations[0].attributes["positive_pairs"].append(image2)
elif len(pair) == 4:
image1, id1 = self.get_image_name(pair[0], pair[1])
if pair[2] == "-":
image2 = pair[3]
id2 = pair[3]
else:
image2, id2 = self.get_image_name(pair[2], pair[3])
if id1 not in items:
annotations = []
label = get_label_id(pair[0])
annotations.append(Label(label))
image = images.get(image1)
if image:
image = Image(path=image)
items[id1] = DatasetItem(
id=id1, subset=self._subset, media=image, annotations=annotations
)
if id2 not in items:
annotations = []
if pair[2] != "-":
label = get_label_id(pair[2])
annotations.append(Label(label))
image = images.get(image2)
if image:
image = Image(path=image)
items[id2] = DatasetItem(
id=id2, subset=self._subset, media=image, annotations=annotations
)
# pairs form a directed graph
if not items[id1].annotations[0].attributes.get("negative_pairs"):
items[id1].annotations[0].attributes["negative_pairs"] = []
items[id1].annotations[0].attributes["negative_pairs"].append(image2)
landmarks_file = osp.join(self._annotations_dir, LfwPath.LANDMARKS_FILE)
if osp.isfile(landmarks_file):
with open(landmarks_file, encoding="utf-8") as f:
for line in f:
line = line.split("\t")
item_id = osp.splitext(line[0])[0]
objects = item_id.split("/")
if 1 < len(objects):
label_name = objects[0]
label = get_label_id(label_name)
if label is not None:
item_id = item_id[len(label_name) + 1 :]
if item_id not in items:
items[item_id] = DatasetItem(
id=item_id,
subset=self._subset,
image=osp.join(self._images_dir, line[0]),
)
annotations = items[item_id].annotations
annotations.append(Points([float(p) for p in line[1:]], label=label))
return items
[docs] @staticmethod
def get_image_name(person, image_id):
image, item_id = "", ""
try:
image_id = int(image_id)
image = "{}/{}_{:04d}".format(person, person, image_id)
item_id = "{}_{:04d}".format(person, image_id)
except ValueError:
image = "{}/{}".format(person, image_id)
item_id = image_id
return image, item_id
[docs]class LfwImporter(Importer):
[docs] @classmethod
def detect(cls, context: FormatDetectionContext) -> None:
context.require_file(f"{LfwPath.ANNOTATION_DIR}/{LfwPath.PAIRS_FILE}")
[docs] @classmethod
def find_sources(cls, path):
base, ext = osp.splitext(LfwPath.PAIRS_FILE)
return cls._find_sources_recursive(
path, ext, "lfw", filename=base, dirname=LfwPath.ANNOTATION_DIR
)
[docs]class LfwConverter(Converter):
DEFAULT_IMAGE_EXT = LfwPath.IMAGE_EXT
[docs] def apply(self):
if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
raise MediaTypeError("Media type is not an image")
os.makedirs(self._save_dir, exist_ok=True)
if self._save_dataset_meta:
self._save_meta_file(self._save_dir)
for subset_name, subset in self._extractor.subsets().items():
label_categories = self._extractor.categories()[AnnotationType.label]
labels = {label.name: 0 for label in label_categories}
positive_pairs = []
negative_pairs = []
neutral_items = []
landmarks = []
included_items = []
for item in subset:
anns = [ann for ann in item.annotations if ann.type == AnnotationType.label]
label, label_name = None, None
if anns:
label = anns[0]
label_name = label_categories[anns[0].label].name
labels[label_name] += 1
if self._save_media and item.media:
subdir = osp.join(subset_name, LfwPath.IMAGES_DIR)
if label_name:
subdir = osp.join(subdir, label_name)
self._save_image(item, subdir=subdir)
if label is not None:
person1 = label_name
num1 = item.id
if num1.startswith(person1):
num1 = int(num1.replace(person1, "")[1:])
curr_item = person1 + "/" + str(num1)
if "positive_pairs" in label.attributes:
if curr_item not in included_items:
included_items.append(curr_item)
for pair in label.attributes["positive_pairs"]:
search = LfwPath.PATTERN.search(pair)
if search:
num2 = search.groups()[1]
num2 = int(num2)
else:
num2 = pair
if num2.startswith(person1):
num2 = num2.replace(person1, "")[1:]
curr_item = person1 + "/" + str(num2)
if curr_item not in included_items:
included_items.append(curr_item)
positive_pairs.append("%s\t%s\t%s" % (person1, num1, num2))
if "negative_pairs" in label.attributes:
if curr_item not in included_items:
included_items.append(curr_item)
for pair in label.attributes["negative_pairs"]:
search = LfwPath.PATTERN.search(pair)
curr_item = ""
if search:
person2, num2 = search.groups()
num2 = int(num2)
curr_item += person2 + "/"
else:
person2 = "-"
num2 = pair
objects = pair.split("/")
if 1 < len(objects) and objects[0] in labels:
person2 = objects[0]
num2 = pair.replace(person2, "")[1:]
curr_item += person2 + "/"
curr_item += str(num2)
if curr_item not in included_items:
included_items.append(curr_item)
negative_pairs.append("%s\t%s\t%s\t%s" % (person1, num1, person2, num2))
if (
"positive_pairs" not in label.attributes
and "negative_pairs" not in label.attributes
and curr_item not in included_items
):
neutral_items.append("%s/%s" % (person1, item.id))
included_items.append(curr_item)
elif item.id not in included_items:
neutral_items.append(item.id)
included_items.append(item.id)
item_landmarks = [p for p in item.annotations if p.type == AnnotationType.points]
for landmark in item_landmarks:
landmarks.append(
"%s\t%s"
% (item.id + LfwPath.IMAGE_EXT, "\t".join(str(p) for p in landmark.points))
)
annotations_dir = osp.join(self._save_dir, subset_name, LfwPath.ANNOTATION_DIR)
pairs_file = osp.join(annotations_dir, LfwPath.PAIRS_FILE)
os.makedirs(osp.dirname(pairs_file), exist_ok=True)
with open(pairs_file, "w", encoding="utf-8") as f:
f.writelines(["%s\n" % pair for pair in positive_pairs])
f.writelines(["%s\n" % pair for pair in negative_pairs])
f.writelines(["%s\n" % item for item in neutral_items])
if landmarks:
landmarks_file = osp.join(annotations_dir, LfwPath.LANDMARKS_FILE)
with open(landmarks_file, "w", encoding="utf-8") as f:
f.writelines(["%s\n" % landmark for landmark in landmarks])
if labels:
people_file = osp.join(annotations_dir, LfwPath.PEOPLE_FILE)
with open(people_file, "w", encoding="utf-8") as f:
f.writelines(["%s\t%d\n" % (label, labels[label]) for label in labels])