Source code for datumaro.plugins.datumaro_format.extractor

# Copyright (C) 2019-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp

from datumaro.components.annotation import (
    AnnotationType,
    Bbox,
    Caption,
    Cuboid3d,
    Label,
    LabelCategories,
    MaskCategories,
    Points,
    PointsCategories,
    Polygon,
    PolyLine,
    RleMask,
)
from datumaro.components.errors import DatasetImportError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.media import Image, MediaElement, PointCloud
from datumaro.util import parse_json, parse_json_file

from .format import DatumaroPath


class DatumaroExtractor(SourceExtractor):
    def __init__(self, path):
        assert osp.isfile(path), path

        rootpath = ""
        if path.endswith(osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))):
            rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]

        images_dir = ""
        if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.IMAGES_DIR)):
            images_dir = osp.join(rootpath, DatumaroPath.IMAGES_DIR)
        self._images_dir = images_dir

        pcd_dir = ""
        if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.PCD_DIR)):
            pcd_dir = osp.join(rootpath, DatumaroPath.PCD_DIR)
        self._pcd_dir = pcd_dir

        related_images_dir = ""
        if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.RELATED_IMAGES_DIR)):
            related_images_dir = osp.join(rootpath, DatumaroPath.RELATED_IMAGES_DIR)
        self._related_images_dir = related_images_dir

        super().__init__(subset=osp.splitext(osp.basename(path))[0])

        parsed_anns = parse_json_file(path)
        self._categories = self._load_categories(parsed_anns)
        self._items = self._load_items(parsed_anns)
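
    # Note (added for readability, not part of the upstream module): the constructor
    # above resolves sibling media directories relative to the annotation file, so a
    # dataset root is expected to follow roughly this layout. The directory names are
    # the DatumaroPath constants imported from .format; the strings shown here are the
    # conventional ones and should be treated as an assumption:
    #
    #   <rootpath>/
    #       annotations/<subset>.json                      # `path` given to __init__
    #       images/<subset>/<item id>.jpg
    #       point_clouds/<subset>/<item id>.pcd
    #       related_images/<subset>/<item id>/<image name>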

    @staticmethod
    def _load_categories(parsed):
        categories = {}

        parsed_label_cat = parsed["categories"].get(AnnotationType.label.name)
        if parsed_label_cat:
            label_categories = LabelCategories(attributes=parsed_label_cat.get("attributes", []))
            for item in parsed_label_cat["labels"]:
                label_categories.add(
                    item["name"], parent=item["parent"], attributes=item.get("attributes", [])
                )

            categories[AnnotationType.label] = label_categories

        parsed_mask_cat = parsed["categories"].get(AnnotationType.mask.name)
        if parsed_mask_cat:
            colormap = {}
            for item in parsed_mask_cat["colormap"]:
                colormap[int(item["label_id"])] = (item["r"], item["g"], item["b"])

            mask_categories = MaskCategories(colormap=colormap)
            categories[AnnotationType.mask] = mask_categories

        parsed_points_cat = parsed["categories"].get(AnnotationType.points.name)
        if parsed_points_cat:
            point_categories = PointsCategories()
            for item in parsed_points_cat["items"]:
                point_categories.add(int(item["label_id"]), item["labels"], joints=item["joints"])

            categories[AnnotationType.points] = point_categories

        return categories

    def _load_items(self, parsed):
        items = []
        for item_desc in parsed["items"]:
            item_id = item_desc["id"]

            media = None
            image_info = item_desc.get("image")
            if image_info:
                image_filename = image_info.get("path") or item_id + DatumaroPath.IMAGE_EXT
                image_path = osp.join(self._images_dir, self._subset, image_filename)
                if not osp.isfile(image_path):
                    # backward compatibility
                    old_image_path = osp.join(self._images_dir, image_filename)
                    if osp.isfile(old_image_path):
                        image_path = old_image_path

                media = Image(path=image_path, size=image_info.get("size"))
                self._media_type = Image

            pcd_info = item_desc.get("point_cloud")
            if media and pcd_info:
                raise DatasetImportError("Dataset cannot contain multiple media types")

            if pcd_info:
                pcd_path = pcd_info.get("path")
                point_cloud = osp.join(self._pcd_dir, self._subset, pcd_path)

                related_images = None
                ri_info = item_desc.get("related_images")
                if ri_info:
                    related_images = [
                        Image(
                            size=ri.get("size"),
                            path=osp.join(
                                self._related_images_dir, self._subset, item_id, ri.get("path")
                            ),
                        )
                        for ri in ri_info
                    ]

                media = PointCloud(point_cloud, extra_images=related_images)
                self._media_type = PointCloud

            media_desc = item_desc.get("media")
            if not media and media_desc and media_desc.get("path"):
                media = MediaElement(path=media_desc.get("path"))
                self._media_type = MediaElement

            annotations = self._load_annotations(item_desc)

            item = DatasetItem(
                id=item_id,
                subset=self._subset,
                annotations=annotations,
                media=media,
                attributes=item_desc.get("attr"),
            )

            items.append(item)

        return items

    @staticmethod
    def _load_annotations(item):
        parsed = item["annotations"]
        loaded = []

        for ann in parsed:
            ann_id = ann.get("id")
            ann_type = AnnotationType[ann["type"]]
            attributes = ann.get("attributes")
            group = ann.get("group")

            label_id = ann.get("label_id")
            z_order = ann.get("z_order")
            points = ann.get("points")

            if ann_type == AnnotationType.label:
                loaded.append(Label(label=label_id, id=ann_id, attributes=attributes, group=group))

            elif ann_type == AnnotationType.mask:
                rle = ann["rle"]
                rle["counts"] = rle["counts"].encode("ascii")
                loaded.append(
                    RleMask(
                        rle=rle,
                        label=label_id,
                        id=ann_id,
                        attributes=attributes,
                        group=group,
                        z_order=z_order,
                    )
                )

            elif ann_type == AnnotationType.polyline:
                loaded.append(
                    PolyLine(
                        points,
                        label=label_id,
                        id=ann_id,
                        attributes=attributes,
                        group=group,
                        z_order=z_order,
                    )
                )

            elif ann_type == AnnotationType.polygon:
                loaded.append(
                    Polygon(
                        points,
                        label=label_id,
                        id=ann_id,
                        attributes=attributes,
                        group=group,
                        z_order=z_order,
                    )
                )

            elif ann_type == AnnotationType.bbox:
                x, y, w, h = ann["bbox"]
                loaded.append(
                    Bbox(
                        x,
                        y,
                        w,
                        h,
                        label=label_id,
                        id=ann_id,
                        attributes=attributes,
                        group=group,
                        z_order=z_order,
                    )
                )

            elif ann_type == AnnotationType.points:
                loaded.append(
                    Points(
                        points,
                        label=label_id,
                        id=ann_id,
                        attributes=attributes,
                        group=group,
                        z_order=z_order,
                    )
                )

            elif ann_type == AnnotationType.caption:
                caption = ann.get("caption")
                loaded.append(Caption(caption, id=ann_id, attributes=attributes, group=group))

            elif ann_type == AnnotationType.cuboid_3d:
                loaded.append(
                    Cuboid3d(
                        ann.get("position"),
                        ann.get("rotation"),
                        ann.get("scale"),
                        label=label_id,
                        id=ann_id,
                        attributes=attributes,
                        group=group,
                    )
                )

            else:
                raise NotImplementedError()

        return loaded
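
# Illustrative sketch (an addition, not part of the upstream module): the dictionary
# shape that parse_json_file() is expected to produce for the loaders above. The keys
# mirror what _load_categories() and _load_items() read; the concrete values are
# hypothetical.
_EXAMPLE_PARSED = {
    "categories": {
        "label": {
            "attributes": ["occluded"],
            "labels": [{"name": "cat", "parent": "", "attributes": []}],
        },
    },
    "items": [
        {
            "id": "frame_000001",
            "annotations": [
                {
                    "id": 0,
                    "type": "bbox",
                    "label_id": 0,
                    "bbox": [10, 20, 30, 40],
                    "attributes": {"occluded": False},
                    "group": 0,
                    "z_order": 0,
                },
            ],
            "image": {"path": "frame_000001.jpg", "size": [480, 640]},
            "attr": {"frame": 1},
        },
    ],
}

# Both category and annotation loaders are static, so they can be exercised directly
# on such a dict without touching the filesystem:
#   DatumaroExtractor._load_categories(_EXAMPLE_PARSED)
#   DatumaroExtractor._load_annotations(_EXAMPLE_PARSED["items"][0])
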
class DatumaroImporter(Importer):
    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        annot_file = context.require_file("annotations/*.json")

        with context.probe_text_file(
            annot_file,
            'must be a JSON object with "categories" and "items" keys',
        ) as f:
            contents = parse_json(f.read())
            if not {"categories", "items"} <= contents.keys():
                raise Exception

    @classmethod
    def find_sources(cls, path):
        return cls._find_sources_recursive(path, ".json", "datumaro", dirname="annotations")
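
if __name__ == "__main__":  # illustration only, not part of the upstream module
    # Minimal usage sketch: datasets in this format are normally loaded through the
    # high-level Dataset API, which dispatches to DatumaroImporter/DatumaroExtractor
    # registered under the "datumaro" format name. The "path/to/dataset" argument is
    # a placeholder.
    from datumaro.components.dataset import Dataset

    dataset = Dataset.import_from("path/to/dataset", "datumaro")
    for loaded_item in dataset:
        print(loaded_item.id, loaded_item.subset, len(loaded_item.annotations))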