Source code for datumaro.util

# Copyright (C) 2019-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from functools import wraps
from inspect import isclass
from itertools import islice
from typing import Iterable, Tuple
import distutils.util

NOTSET = object()

str_to_bool = distutils.util.strtobool

[docs]def find(iterable, pred=lambda x: True, default=None): return next((x for x in iterable if pred(x)), default)
[docs]def cast(value, type_conv, default=None): if value is None: return default try: return type_conv(value) except Exception: return default
[docs]def to_snake_case(s): if not s: return '' name = [s[0].lower()] for idx, char in enumerate(s[1:]): idx = idx + 1 if char.isalpha() and char.isupper(): prev_char = s[idx - 1] if not (prev_char.isalpha() and prev_char.isupper()): # avoid "HTML" -> "h_t_m_l" name.append('_') name.append(char.lower()) else: name.append(char) return ''.join(name)
[docs]def pairs(iterable): a = iter(iterable) return zip(a, a)
[docs]def take_by(iterable, count): """ Returns elements from the input iterable by batches of N items. ('abcdefg', 3) -> ['a', 'b', 'c'], ['d', 'e', 'f'], ['g'] """ it = iter(iterable) while True: batch = list(islice(it, count)) if len(batch) == 0: break yield batch
[docs]def filter_dict(d, exclude_keys): return { k: v for k, v in d.items() if k not in exclude_keys }
[docs]def parse_str_enum_value(value, enum_class, default=NOTSET, unknown_member_error=None): if value is None and default is not NOTSET: value = default elif isinstance(value, str): try: value = enum_class[value] except KeyError: raise ValueError((unknown_member_error or "Unknown element of {cls} '{value}'. " "The only known are: {available}") \ .format( cls=enum_class.__name__, value=value, available=', '.join(e.name for e in enum_class) ) ) elif isinstance(value, enum_class): pass else: raise TypeError("Expected value type string or %s, but got %s" % \ (enum_class.__name__, type(value).__name__)) return value
[docs]def escape(s: str, escapes: Iterable[Tuple[str, str]]) -> str: """ 'escapes' is an iterable of (pattern, substitute) pairs """ for pattern, sub in escapes: s = s.replace(pattern, sub) return s
[docs]def unescape(s: str, escapes: Iterable[Tuple[str, str]]) -> str: """ 'escapes' is an iterable of (pattern, substitute) pairs """ for pattern, sub in escapes: s = s.replace(sub, pattern) return s
[docs]def is_method_redefined(method_name, base_class, target) -> bool: target_method = getattr(target, method_name, None) if not isclass(target) and target_method: target_method = getattr(target_method, '__func__', None) return getattr(base_class, method_name) != target_method
[docs]def optional_arg_decorator(fn): @wraps(fn) def wrapped_decorator(*args, **kwargs): if len(args) == 1 and callable(args[0]) and not kwargs: return fn(args[0], **kwargs) else: def real_decorator(decoratee): return fn(decoratee, *args, **kwargs) return real_decorator return wrapped_decorator