# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
import logging as log
import os
import os.path as osp
from copy import deepcopy
# Disable B406: import_xml_sax - the library is used for writing
from xml.sax.saxutils import XMLGenerator  # nosec
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.converter import Converter
from datumaro.components.dataset import ItemStatus
from datumaro.components.errors import MediaTypeError
from datumaro.components.extractor import DatasetItem
from datumaro.components.media import PointCloud
from datumaro.util import cast
from datumaro.util.image import find_images
from .format import KittiRawPath, OcclusionStates, PoseStates, TruncationStates
class _XmlAnnotationWriter:
    # Format constants
    _tracking_level = 0
    _tracklets_class_id = 0
    _tracklets_version = 0
    _tracklet_class_id = 1
    _tracklet_version = 1
    _poses_class_id = 2
    _poses_version = 0
    _pose_class_id = 3
    _pose_version = 1
    # XML headers
    _header = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"""
    _doctype = "<!DOCTYPE boost_serialization>"
    def __init__(self, file, tracklets):
        self._file = file
        self._tracklets = tracklets
        self._xmlgen = XMLGenerator(self._file, encoding="utf-8")
        self._level = 0
        # See reference for section headers here:
        # https://www.boost.org/doc/libs/1_40_0/libs/serialization/doc/traits.html
        # XML archives have regular structure, so we only include headers once
        self._add_tracklet_header = True
        self._add_poses_header = True
        self._add_pose_header = True
    def _indent(self, newline=True):
        if newline:
            self._xmlgen.ignorableWhitespace("\n")
        self._xmlgen.ignorableWhitespace("  " * self._level)
    def _add_headers(self):
        self._file.write(self._header)
        self._indent(newline=True)
        self._file.write(self._doctype)
    def _open_serialization(self):
        self._indent(newline=True)
        self._xmlgen.startElement(
            "boost_serialization", {"version": "9", "signature": "serialization::archive"}
        )
    def _close_serialization(self):
        self._indent(newline=True)
        self._xmlgen.endElement("boost_serialization")
    def _add_count(self, count):
        self._indent(newline=True)
        self._xmlgen.startElement("count", {})
        self._xmlgen.characters(str(count))
        self._xmlgen.endElement("count")
    def _add_item_version(self, version):
        self._indent(newline=True)
        self._xmlgen.startElement("item_version", {})
        self._xmlgen.characters(str(version))
        self._xmlgen.endElement("item_version")
    def _open_tracklets(self, tracklets):
        self._indent(newline=True)
        self._xmlgen.startElement(
            "tracklets",
            {
                "version": str(self._tracklets_version),
                "tracking_level": str(self._tracking_level),
                "class_id": str(self._tracklets_class_id),
            },
        )
        self._level += 1
        self._add_count(len(tracklets))
        self._add_item_version(self._tracklet_version)
    def _close_tracklets(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("tracklets")
    def _open_tracklet(self):
        self._indent(newline=True)
        if self._add_tracklet_header:
            self._xmlgen.startElement(
                "item",
                {
                    "version": str(self._tracklet_class_id),
                    "tracking_level": str(self._tracking_level),
                    "class_id": str(self._tracklet_class_id),
                },
            )
            self._add_tracklet_header = False
        else:
            self._xmlgen.startElement("item", {})
        self._level += 1
    def _close_tracklet(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("item")
    def _add_tracklet(self, tracklet):
        self._open_tracklet()
        for key, value in tracklet.items():
            if key == "poses":
                self._add_poses(value)
            elif key == "attributes":
                self._add_attributes(value)
            else:
                self._indent(newline=True)
                self._xmlgen.startElement(key, {})
                self._xmlgen.characters(str(value))
                self._xmlgen.endElement(key)
        self._close_tracklet()
    def _open_poses(self, poses):
        self._indent(newline=True)
        if self._add_poses_header:
            self._xmlgen.startElement(
                "poses",
                {
                    "version": str(self._poses_version),
                    "tracking_level": str(self._tracking_level),
                    "class_id": str(self._poses_class_id),
                },
            )
            self._add_poses_header = False
        else:
            self._xmlgen.startElement("poses", {})
        self._level += 1
        self._add_count(len(poses))
        self._add_item_version(self._poses_version)
    def _close_poses(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("poses")
    def _add_poses(self, poses):
        self._open_poses(poses)
        for pose in poses:
            self._add_pose(pose)
        self._close_poses()
    def _open_pose(self):
        self._indent(newline=True)
        if self._add_pose_header:
            self._xmlgen.startElement(
                "item",
                {
                    "version": str(self._pose_version),
                    "tracking_level": str(self._tracking_level),
                    "class_id": str(self._pose_class_id),
                },
            )
            self._add_pose_header = False
        else:
            self._xmlgen.startElement("item", {})
        self._level += 1
    def _close_pose(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("item")
    def _add_pose(self, pose):
        self._open_pose()
        for key, value in pose.items():
            if key == "attributes":
                self._add_attributes(value)
            elif key != "frame_id":
                self._indent(newline=True)
                self._xmlgen.startElement(key, {})
                self._xmlgen.characters(str(value))
                self._xmlgen.endElement(key)
        self._close_pose()
    def _open_attributes(self):
        self._indent(newline=True)
        self._xmlgen.startElement("attributes", {})
        self._level += 1
    def _close_attributes(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("attributes")
    def _add_attributes(self, attributes):
        self._open_attributes()
        for name, value in attributes.items():
            self._add_attribute(name, value)
        self._close_attributes()
    def _open_attribute(self):
        self._indent(newline=True)
        self._xmlgen.startElement("attribute", {})
        self._level += 1
    def _close_attribute(self):
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement("attribute")
    def _add_attribute(self, name, value):
        self._open_attribute()
        self._indent(newline=True)
        self._xmlgen.startElement("name", {})
        self._xmlgen.characters(name)
        self._xmlgen.endElement("name")
        self._xmlgen.startElement("value", {})
        self._xmlgen.characters(str(value))
        self._xmlgen.endElement("value")
        self._close_attribute()
    def write(self):
        self._add_headers()
        self._open_serialization()
        self._open_tracklets(self._tracklets)
        for tracklet in self._tracklets:
            self._add_tracklet(tracklet)
        self._close_tracklets()
        self._close_serialization()
class KittiRawConverter(Converter):
    """Export a point cloud dataset with 3D cuboid tracks in KITTI Raw format.

    Writes three things:

    - the tracklets annotation file (``KittiRawPath.ANNO_FILE``);
    - a frame id -> item id mapping file (``KittiRawPath.NAME_MAPPING_FILE``);
    - when media saving is enabled, the point clouds and their extra images.
    """

    DEFAULT_IMAGE_EXT = ".jpg"

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base parser with ``--reindex`` and ``--allow-attrs``."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument(
            "--reindex",
            action="store_true",
            help="Assign new indices to frames and tracks. "
            "Allows annotations without 'track_id' (default: %(default)s)",
        )
        parser.add_argument(
            "--allow-attrs",
            action="store_true",
            help="Allow writing annotation attributes (default: %(default)s)",
        )
        return parser

    def __init__(self, extractor, save_dir, reindex=False, allow_attrs=False, **kwargs):
        super().__init__(extractor, save_dir, **kwargs)

        # Ignore the items' "frame" attributes and generate ids for
        # annotations without "track_id"
        self._reindex = reindex
        # Attributes that are never exported as free-form pose attributes
        self._builtin_attrs = KittiRawPath.BUILTIN_ATTRS | KittiRawPath.SPECIAL_ATTRS
        self._allow_attrs = allow_attrs

    @staticmethod
    def _parse_state(enum_cls, value):
        """Convert a state attribute (occlusion, truncation) to an ``enum_cls`` member.

        Strings are matched case-insensitively against member names
        (``"partly"`` -> ``PARTLY``). If a string is not a member name, its
        upper-cased form is looked up as a member value. Non-string values are
        looked up as member values directly.

        Raises:
            ValueError: if ``value`` matches no member.
        """
        if isinstance(value, str):
            key = value.upper()
            if key in enum_cls.__members__:
                return enum_cls[key]
            return enum_cls(key)
        return enum_cls(value)

    @staticmethod
    def _pad_track_gap(track, frame_id):
        """Fill frames skipped by ``track`` before ``frame_id`` with outside poses.

        If the track's last pose is not from frame ``frame_id - 1``, that pose
        is marked as a keyframe. It is then copied into every missing frame,
        with unset occlusion and out-of-image truncation.
        """
        last_key_pose = track["poses"][-1]
        last_keyframe_id = last_key_pose["frame_id"]
        if frame_id == last_keyframe_id + 1:
            return

        last_key_pose["occlusion_kf"] = 1
        for i in range(last_keyframe_id + 1, frame_id):
            pose = deepcopy(last_key_pose)
            # Store raw values, like the poses built in _make_pose: the XML
            # writer serializes values with str()
            pose["occlusion"] = OcclusionStates.OCCLUSION_UNSET.value
            pose["truncation"] = TruncationStates.OUT_IMAGE.value
            pose["frame_id"] = i
            track["poses"].append(pose)

    def _make_pose(self, ann, frame_id):
        """Build the pose dict for cuboid ``ann`` on frame ``frame_id``.

        The key order determines the element order in the output XML.
        """
        occlusion = OcclusionStates.VISIBLE
        if "occlusion" in ann.attributes:
            occlusion = self._parse_state(OcclusionStates, ann.attributes["occlusion"])
        elif "occluded" in ann.attributes:
            if ann.attributes["occluded"]:
                occlusion = OcclusionStates.PARTLY

        truncation = TruncationStates.IN_IMAGE
        if "truncation" in ann.attributes:
            truncation = self._parse_state(TruncationStates, ann.attributes["truncation"])

        pose = {
            "tx": ann.position[0],
            "ty": ann.position[1],
            "tz": ann.position[2],
            "rx": ann.rotation[0],
            "ry": ann.rotation[1],
            "rz": ann.rotation[2],
            "state": PoseStates.LABELED.value,
            "occlusion": occlusion.value,
            "occlusion_kf": int(ann.attributes.get("keyframe", False) is True),
            "truncation": truncation.value,
            "amt_occlusion": -1,
            "amt_border_l": -1,
            "amt_border_r": -1,
            "amt_occlusion_kf": -1,
            "amt_border_kf": -1,
            "frame_id": frame_id,
        }

        if self._allow_attrs:
            attributes = {}
            for name, value in ann.attributes.items():
                if name in self._builtin_attrs:
                    continue
                if isinstance(value, bool):
                    value = "true" if value else "false"
                attributes[name] = value
            pose["attributes"] = attributes

        return pose

    def _create_tracklets(self, subset):
        """Group the cuboid annotations of all items in ``subset`` into tracklets.

        As side effects, saves each item's media (see ``_write_item``) and
        writes the frame name mapping file.

        Returns:
            A list of tracklet dicts, sorted by track id.

        Raises:
            Exception: if a frame id is repeated, an annotation lacks an
                integer ``track_id`` and reindexing is off, or a track changes
                its scale or label between frames.
        """
        tracks = {}  # track_id -> track
        name_mapping = {}  # frame_id -> name

        for frame_id, item in enumerate(subset):
            frame_id = self._write_item(item, frame_id)

            if frame_id in name_mapping:
                raise Exception(
                    "Item %s: frame id %s is repeated in the dataset" % (item.id, frame_id)
                )
            name_mapping[frame_id] = item.id

            for ann in item.annotations:
                if ann.type != AnnotationType.cuboid_3d:
                    continue

                if ann.label is None:
                    log.warning(
                        "Item %s: skipping a %s%s with no label",
                        item.id,
                        ann.type.name,
                        "(#%s) " % ann.id if ann.id is not None else "",
                    )
                    continue

                label = self._get_label(ann.label).name

                track_id = cast(ann.attributes.get("track_id"), int, None)
                if self._reindex and track_id is None:
                    # In this format, track id is not used for anything except
                    # annotation grouping. So we only need to pick a definitely
                    # unused id. A negative one, for example.
                    track_id = -(len(tracks) + 1)
                if track_id is None:
                    raise Exception(
                        "Item %s: expected track annotations "
                        "having 'track_id' (integer) attribute. "
                        "Use --reindex to export single shapes." % item.id
                    )

                track = tracks.get(track_id)
                if not track:
                    # Key order determines the element order in the output XML
                    track = {
                        "objectType": label,
                        "h": ann.scale[1],
                        "w": ann.scale[0],
                        "l": ann.scale[2],
                        "first_frame": frame_id,
                        "poses": [],
                        "finished": 1,  # keep last
                    }
                    tracks[track_id] = track
                else:
                    if [track["w"], track["h"], track["l"]] != ann.scale:
                        # Tracks have fixed scale in the format
                        raise Exception(
                            "Item %s: mismatching track shapes, "
                            "track id %s" % (item.id, track_id)
                        )

                    if track["objectType"] != label:
                        raise Exception(
                            "Item %s: mismatching track labels, "
                            "track id %s: %s vs. %s"
                            % (item.id, track_id, track["objectType"], label)
                        )

                    # If there is a skip in track frames, add missing as outside
                    self._pad_track_gap(track, frame_id)

                track["poses"].append(self._make_pose(ann, frame_id))

        self._write_name_mapping(name_mapping)

        return [track for _, track in sorted(tracks.items(), key=lambda e: e[0])]

    def _write_name_mapping(self, name_mapping):
        """Write ``frame_id item_id`` lines to the name mapping file in ``save_dir``."""
        with open(
            osp.join(self._save_dir, KittiRawPath.NAME_MAPPING_FILE), "w", encoding="utf-8"
        ) as f:
            f.writelines("%s %s\n" % (frame_id, name) for frame_id, name in name_mapping.items())

    def _get_label(self, label_id):
        """Return the label category entry for ``label_id``, or ``""`` for ``None``."""
        if label_id is None:
            return ""
        label_cat = self._extractor.categories().get(AnnotationType.label, LabelCategories())
        return label_cat.items[label_id]

    def _write_item(self, item, index):
        """Save the item's media if enabled and return the frame id to use for it.

        Without reindexing, the frame id comes from the item's ``"frame"``
        attribute, with ``index`` as the ``cast`` default (presumably used
        when the attribute is missing or not an integer).

        Extra images are sorted by path. Image ``i`` is saved as
        ``<save_dir>/<IMG_DIR_PREFIX><i:02d>/data/<item.id><ext>``.
        """
        if not self._reindex:
            index = cast(item.attributes.get("frame"), int, index)

        if self._save_media and item.media:
            self._save_point_cloud(item, subdir=KittiRawPath.PCD_DIR)

            images = sorted(item.media.extra_images, key=lambda img: img.path)
            for i, image in enumerate(images):
                if image.has_data:
                    image.save(
                        osp.join(
                            self._save_dir,
                            KittiRawPath.IMG_DIR_PREFIX + ("%02d" % i),
                            "data",
                            item.id + self._find_image_ext(image),
                        )
                    )

        elif self._save_media and not item.media:
            log.debug("Item '%s' has no image info", item.id)

        return index

    def apply(self):
        """Export the dataset into ``save_dir``.

        Raises:
            MediaTypeError: if the dataset declares a media type other than
                ``PointCloud``.
        """
        if self._extractor.media_type() and self._extractor.media_type() is not PointCloud:
            raise MediaTypeError("Media type is not a point cloud")

        os.makedirs(self._save_dir, exist_ok=True)

        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        if len(self._extractor.subsets()) > 1:
            log.warning(
                "Kitti RAW format supports only a single "
                "subset. Subset information will be ignored on export."
            )

        tracklets = self._create_tracklets(self._extractor)
        with open(osp.join(self._save_dir, KittiRawPath.ANNO_FILE), "w", encoding="utf-8") as f:
            writer = _XmlAnnotationWriter(f, tracklets)
            writer.write()

    @classmethod
    def patch(cls, dataset, patch, save_dir, **kwargs):
        """Re-export ``dataset`` with ``patch`` applied and clean up stale media.

        After the export, point cloud and image files are deleted for updated
        items that were removed or no longer have media.
        """
        conv = cls(patch.as_dataset(dataset), save_dir=save_dir, **kwargs)
        conv.apply()

        pcd_dir = osp.abspath(osp.join(save_dir, KittiRawPath.PCD_DIR))
        for (item_id, subset), status in patch.updated_items.items():
            if status != ItemStatus.removed:
                item = patch.data.get(item_id, subset)
            else:
                item = DatasetItem(item_id, subset=subset)

            # Files of items that still have media are up to date
            if not (status == ItemStatus.removed or not item.media):
                continue

            pcd_path = osp.join(pcd_dir, conv._make_pcd_filename(item))
            if osp.isfile(pcd_path):
                os.unlink(pcd_path)

            item_name = osp.basename(item.id)
            for d in os.listdir(save_dir):
                image_dir = osp.join(save_dir, d, "data", osp.dirname(item.id))
                if d.startswith(KittiRawPath.IMG_DIR_PREFIX) and osp.isdir(image_dir):
                    for p in find_images(image_dir):
                        if osp.splitext(osp.basename(p))[0] == item_name:
                            os.unlink(p)