Source code for datumaro.plugins.kitti_raw_format.converter

# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from copy import deepcopy
# Disable B406: import_xml_sax - the library is used for writing
from xml.sax.saxutils import XMLGenerator  # nosec
import logging as log
import os
import os.path as osp

from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.converter import Converter
from datumaro.components.dataset import ItemStatus
from datumaro.components.extractor import DatasetItem
from datumaro.util import cast
from datumaro.util.image import find_images

from .format import KittiRawPath, OcclusionStates, PoseStates, TruncationStates


class _XmlAnnotationWriter:
    # Format constants
    _tracking_level = 0

    _tracklets_class_id = 0
    _tracklets_version = 0

    _tracklet_class_id = 1
    _tracklet_version = 1

    _poses_class_id = 2
    _poses_version = 0

    _pose_class_id = 3
    _pose_version = 1

    # XML headers
    _header = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"""
    _doctype = "<!DOCTYPE boost_serialization>"

    def __init__(self, file, tracklets):
        self._file = file
        self._tracklets = tracklets

        self._xmlgen = XMLGenerator(self._file, encoding='utf-8')
        self._level = 0

        # See reference for section headers here:
        # https://www.boost.org/doc/libs/1_40_0/libs/serialization/doc/traits.html
        # XML archives have regular structure, so we only include headers once
        self._add_tracklet_header = True
        self._add_poses_header = True
        self._add_pose_header = True

    def _indent(self, newline=True):
        if newline:
            self._xmlgen.ignorableWhitespace("\n")
        self._xmlgen.ignorableWhitespace("  " * self._level)

    def _add_headers(self):
        self._file.write(self._header)

        self._indent(newline=True)
        self._file.write(self._doctype)

    def _open_serialization(self):
        self._indent(newline=True)
        self._xmlgen.startElement("boost_serialization", {
            "version": "9", "signature": "serialization::archive"
        })

    def _close_serialization(self):
        self._indent(newline=True)
        self._xmlgen.endElement("boost_serialization")

    def _add_count(self, count):
        self._indent(newline=True)
        self._xmlgen.startElement("count", {})
        self._xmlgen.characters(str(count))
        self._xmlgen.endElement("count")

    def _add_item_version(self, version):
        self._indent(newline=True)
        self._xmlgen.startElement("item_version", {})
        self._xmlgen.characters(str(version))
        self._xmlgen.endElement("item_version")

    def _open_tracklets(self, tracklets):
        self._indent(newline=True)
        self._xmlgen.startElement("tracklets", {
            "version": str(self._tracklets_version),
            "tracking_level": str(self._tracking_level),
            "class_id": str(self._tracklets_class_id),
        })
        self._level += 1
        self._add_count(len(tracklets))
        self._add_item_version(self._tracklet_version)

    def _close_tracklets(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("tracklets")

    def _open_tracklet(self):
        self._indent(newline=True)
        if self._add_tracklet_header:
            self._xmlgen.startElement("item", {
                "version": str(self._tracklet_class_id),
                "tracking_level": str(self._tracking_level),
                "class_id": str(self._tracklet_class_id),
            })
            self._add_tracklet_header = False
        else:
            self._xmlgen.startElement("item", {})
        self._level += 1

    def _close_tracklet(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("item")

    def _add_tracklet(self, tracklet):
        self._open_tracklet()

        for key, value in tracklet.items():
            if key == "poses":
                self._add_poses(value)
            elif key == "attributes":
                self._add_attributes(value)
            else:
                self._indent(newline=True)
                self._xmlgen.startElement(key, {})
                self._xmlgen.characters(str(value))
                self._xmlgen.endElement(key)

        self._close_tracklet()

    def _open_poses(self, poses):
        self._indent(newline=True)
        if self._add_poses_header:
            self._xmlgen.startElement("poses", {
                "version": str(self._poses_version),
                "tracking_level": str(self._tracking_level),
                "class_id": str(self._poses_class_id),
            })
            self._add_poses_header = False
        else:
            self._xmlgen.startElement("poses", {})
        self._level += 1

        self._add_count(len(poses))
        self._add_item_version(self._poses_version)

    def _close_poses(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("poses")

    def _add_poses(self, poses):
        self._open_poses(poses)

        for pose in poses:
            self._add_pose(pose)

        self._close_poses()

    def _open_pose(self):
        self._indent(newline=True)
        if self._add_pose_header:
            self._xmlgen.startElement("item", {
                "version": str(self._pose_version),
                "tracking_level": str(self._tracking_level),
                "class_id": str(self._pose_class_id),
            })
            self._add_pose_header = False
        else:
            self._xmlgen.startElement("item", {})
        self._level += 1

    def _close_pose(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("item")

    def _add_pose(self, pose):
        self._open_pose()

        for key, value in pose.items():
            if key == 'attributes':
                self._add_attributes(value)
            elif key != 'frame_id':
                self._indent(newline=True)
                self._xmlgen.startElement(key, {})
                self._xmlgen.characters(str(value))
                self._xmlgen.endElement(key)

        self._close_pose()

    def _open_attributes(self):
        self._indent(newline=True)
        self._xmlgen.startElement("attributes", {})
        self._level += 1

    def _close_attributes(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("attributes")

    def _add_attributes(self, attributes):
        self._open_attributes()

        for name, value in attributes.items():
            self._add_attribute(name, value)

        self._close_attributes()

    def _open_attribute(self):
        self._indent(newline=True)
        self._xmlgen.startElement("attribute", {})
        self._level += 1

    def _close_attribute(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("attribute")

    def _add_attribute(self, name, value):
        self._open_attribute()

        self._indent(newline=True)
        self._xmlgen.startElement("name", {})
        self._xmlgen.characters(name)
        self._xmlgen.endElement("name")

        self._xmlgen.startElement("value", {})
        self._xmlgen.characters(str(value))
        self._xmlgen.endElement("value")

        self._close_attribute()

    def write(self):
        self._add_headers()
        self._open_serialization()

        self._open_tracklets(self._tracklets)

        for tracklet in self._tracklets:
            self._add_tracklet(tracklet)

        self._close_tracklets()

        self._close_serialization()


[docs]class KittiRawConverter(Converter): DEFAULT_IMAGE_EXT = ".jpg"
[docs] @classmethod def build_cmdline_parser(cls, **kwargs): parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--reindex', action='store_true', help="Assign new indices to frames and tracks. " "Allows annotations without 'track_id' (default: %(default)s)") parser.add_argument('--allow-attrs', action='store_true', help="Allow writing annotation attributes (default: %(default)s)") return parser
[docs] def __init__(self, extractor, save_dir, reindex=False, allow_attrs=False, **kwargs): super().__init__(extractor, save_dir, **kwargs) self._reindex = reindex self._builtin_attrs = \ KittiRawPath.BUILTIN_ATTRS | KittiRawPath.SPECIAL_ATTRS self._allow_attrs = allow_attrs
def _create_tracklets(self, subset): tracks = {} # track_id -> track name_mapping = {} # frame_id -> name for frame_id, item in enumerate(subset): frame_id = self._write_item(item, frame_id) if frame_id in name_mapping: raise Exception( "Item %s: frame id %s is repeated in the dataset" % \ (item.id, frame_id)) name_mapping[frame_id] = item.id for ann in item.annotations: if ann.type != AnnotationType.cuboid_3d: continue if ann.label is None: log.warning("Item %s: skipping a %s%s with no label", item.id, ann.type.name, '(#%s) ' % ann.id if ann.id is not None else '') continue label = self._get_label(ann.label).name track_id = cast(ann.attributes.get('track_id'), int, None) if self._reindex and track_id is None: # In this format, track id is not used for anything except # annotation grouping. So we only need to pick a definitely # unused id. A negative one, for example. track_id = -(len(tracks) + 1) if track_id is None: raise Exception("Item %s: expected track annotations " "having 'track_id' (integer) attribute. " "Use --reindex to export single shapes." % item.id) track = tracks.get(track_id) if not track: track = { "objectType": label, "h": ann.scale[1], "w": ann.scale[0], "l": ann.scale[2], "first_frame": frame_id, "poses": [], "finished": 1 # keep last } tracks[track_id] = track else: if [track['w'], track['h'], track['l']] != ann.scale: # Tracks have fixed scale in the format raise Exception("Item %s: mismatching track shapes, " \ "track id %s" % (item.id, track_id)) if track['objectType'] != label: raise Exception("Item %s: mismatching track labels, " \ "track id %s: %s vs. %s" % \ (item.id, track_id, track['objectType'], label)) # If there is a skip in track frames, add missing as outside if frame_id != track['poses'][-1]['frame_id'] + 1: last_key_pose = track['poses'][-1] last_keyframe_id = last_key_pose['frame_id'] last_key_pose['occlusion_kf'] = 1 for i in range(last_keyframe_id + 1, frame_id): pose = deepcopy(last_key_pose) pose['occlusion'] = OcclusionStates.OCCLUSION_UNSET pose['truncation'] = TruncationStates.OUT_IMAGE pose['frame_id'] = i track['poses'].append(pose) occlusion = OcclusionStates.VISIBLE if 'occlusion' in ann.attributes: occlusion = OcclusionStates( ann.attributes['occlusion'].upper()) elif 'occluded' in ann.attributes: if ann.attributes['occluded']: occlusion = OcclusionStates.PARTLY truncation = TruncationStates.IN_IMAGE if 'truncation' in ann.attributes: truncation = TruncationStates( ann.attributes['truncation'].upper()) pose = { "tx": ann.position[0], "ty": ann.position[1], "tz": ann.position[2], "rx": ann.rotation[0], "ry": ann.rotation[1], "rz": ann.rotation[2], "state": PoseStates.LABELED.value, "occlusion": occlusion.value, "occlusion_kf": \ int(ann.attributes.get("keyframe", False) is True), "truncation": truncation.value, "amt_occlusion": -1, "amt_border_l": -1, "amt_border_r": -1, "amt_occlusion_kf": -1, "amt_border_kf": -1, "frame_id": frame_id, } if self._allow_attrs: attributes = {} for name, value in ann.attributes.items(): if name in self._builtin_attrs: continue if isinstance(value, bool): value = 'true' if value else 'false' attributes[name] = value pose["attributes"] = attributes track["poses"].append(pose) self._write_name_mapping(name_mapping) return [e[1] for e in sorted(tracks.items(), key=lambda e: e[0])] def _write_name_mapping(self, name_mapping): with open(osp.join(self._save_dir, KittiRawPath.NAME_MAPPING_FILE), 'w', encoding='utf-8') as f: f.writelines('%s %s\n' % (frame_id, name) for frame_id, name in name_mapping.items()) def _get_label(self, label_id): if label_id is None: return "" label_cat = self._extractor.categories().get( AnnotationType.label, LabelCategories()) return label_cat.items[label_id] def _write_item(self, item, index): if not self._reindex: index = cast(item.attributes.get('frame'), int, index) if self._save_images: if item.has_point_cloud: self._save_point_cloud(item, subdir=KittiRawPath.PCD_DIR) images = sorted(item.related_images, key=lambda img: img.path) for i, image in enumerate(images): if image.has_data: image.save(osp.join(self._save_dir, KittiRawPath.IMG_DIR_PREFIX + ('%02d' % i), 'data', item.id + self._find_image_ext(image))) else: log.debug("Item '%s' has no image info", item.id) return index
[docs] def apply(self): os.makedirs(self._save_dir, exist_ok=True) if self._save_dataset_meta: self._save_meta_file(self._save_dir) if 1 < len(self._extractor.subsets()): log.warning("Kitti RAW format supports only a single " "subset. Subset information will be ignored on export.") tracklets = self._create_tracklets(self._extractor) with open(osp.join(self._save_dir, KittiRawPath.ANNO_FILE), 'w', encoding='utf-8') as f: writer = _XmlAnnotationWriter(f, tracklets) writer.write()
[docs] @classmethod def patch(cls, dataset, patch, save_dir, **kwargs): conv = cls(patch.as_dataset(dataset), save_dir=save_dir, **kwargs) conv.apply() pcd_dir = osp.abspath(osp.join(save_dir, KittiRawPath.PCD_DIR)) for (item_id, subset), status in patch.updated_items.items(): if status != ItemStatus.removed: item = patch.data.get(item_id, subset) else: item = DatasetItem(item_id, subset=subset) if not (status == ItemStatus.removed or not item.has_point_cloud): continue pcd_path = osp.join(pcd_dir, conv._make_pcd_filename(item)) if osp.isfile(pcd_path): os.unlink(pcd_path) for d in os.listdir(save_dir): image_dir = osp.join(save_dir, d, 'data', osp.dirname(item.id)) if d.startswith(KittiRawPath.IMG_DIR_PREFIX) and \ osp.isdir(image_dir): for p in find_images(image_dir): if osp.splitext(osp.basename(p))[0] == \ osp.basename(item.id): os.unlink(p)