Source code for datumaro.util

# Copyright (C) 2019-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from functools import wraps
from inspect import isclass
from itertools import islice
from typing import Any, Callable, Iterable, Tuple, TypeVar, Union

import attrs
import orjson

NOTSET = object()

str_to_bool = attrs.converters.to_bool

T = TypeVar("T")
U = TypeVar("U")


[docs]def find( iterable: Iterable[T], pred: Callable[[T], bool] = lambda x: True, *, default: U = None ) -> Union[T, U]: return next((x for x in iterable if pred(x)), default)
[docs]def cast(value, type_conv, default=None): if value is None: return default try: return type_conv(value) except Exception: return default
[docs]def to_snake_case(s: str) -> str: if not s: return "" name = [s[0].lower()] for idx, char in enumerate(s[1:]): idx = idx + 1 if char.isalpha() and char.isupper(): prev_char = s[idx - 1] if not (prev_char.isalpha() and prev_char.isupper()): # avoid "HTML" -> "h_t_m_l" name.append("_") name.append(char.lower()) else: name.append(char) return "".join(name)
[docs]def pairs(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]: a = iter(iterable) return zip(a, a)
[docs]def take_by(iterable, count): """ Returns elements from the input iterable by batches of N items. ('abcdefg', 3) -> ['a', 'b', 'c'], ['d', 'e', 'f'], ['g'] """ it = iter(iterable) while True: batch = list(islice(it, count)) if len(batch) == 0: break yield batch
[docs]def filter_dict(d, exclude_keys): return {k: v for k, v in d.items() if k not in exclude_keys}
[docs]def parse_str_enum_value(value, enum_class, default=NOTSET, unknown_member_error=None): if value is None and default is not NOTSET: value = default elif isinstance(value, str): try: value = enum_class[value] except KeyError: raise ValueError( ( unknown_member_error or "Unknown element of {cls} '{value}'. " "The only known are: {available}" ).format( cls=enum_class.__name__, value=value, available=", ".join(e.name for e in enum_class), ) ) elif isinstance(value, enum_class): pass else: raise TypeError( "Expected value type string or %s, but got %s" % (enum_class.__name__, type(value).__name__) ) return value
[docs]def escape(s: str, escapes: Iterable[Tuple[str, str]]) -> str: """ 'escapes' is an iterable of (pattern, substitute) pairs """ for pattern, sub in escapes: s = s.replace(pattern, sub) return s
[docs]def unescape(s: str, escapes: Iterable[Tuple[str, str]]) -> str: """ 'escapes' is an iterable of (pattern, substitute) pairs """ for pattern, sub in escapes: s = s.replace(sub, pattern) return s
[docs]def is_method_redefined(method_name, base_class, target) -> bool: target_method = getattr(target, method_name, None) if not isclass(target) and target_method: target_method = getattr(target_method, "__func__", None) return getattr(base_class, method_name) != target_method
[docs]def optional_arg_decorator(fn): @wraps(fn) def wrapped_decorator(*args, **kwargs): if len(args) == 1 and callable(args[0]) and not kwargs: return fn(args[0], **kwargs) else: def real_decorator(decoratee): return fn(decoratee, *args, **kwargs) return real_decorator return wrapped_decorator
[docs]def parse_json(data: Union[str, bytes]): return orjson.loads(data)
[docs]def parse_json_file(path: str): with open(path, "rb") as f: return parse_json(f.read())
[docs]def dump_json( data: Any, *, sort_keys: bool = False, allow_numpy: bool = True, indent: bool = False, append_newline: bool = False, ) -> bytes: flags = 0 if sort_keys: flags |= orjson.OPT_SORT_KEYS if allow_numpy: flags |= orjson.OPT_SERIALIZE_NUMPY if indent: flags |= orjson.OPT_INDENT_2 if append_newline: flags |= orjson.OPT_APPEND_NEWLINE return orjson.dumps(data, option=flags)
[docs]def dump_json_file( path: str, data: Any, *, sort_keys: bool = False, allow_numpy: bool = True, indent: bool = False, append_newline: bool = False, ) -> None: with open(path, "wb") as outfile: outfile.write( dump_json( data, sort_keys=sort_keys, allow_numpy=allow_numpy, indent=indent, append_newline=append_newline, ) )