# Copyright (C) 2019-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
from functools import wraps
from inspect import isclass
from itertools import islice
from typing import Iterable, Tuple
import distutils.util
# Sentinel for "argument not supplied", distinct from None so that None can
# still be passed explicitly by callers.
NOTSET = object()

def str_to_bool(val):
    """
    Converts a string representation of truth to 1 (true) or 0 (false).

    True values are 'y', 'yes', 't', 'true', 'on' and '1'; false values are
    'n', 'no', 'f', 'false', 'off' and '0' (case-insensitive).
    Raises ValueError for anything else.

    This is a drop-in replacement for distutils.util.strtobool, which is
    deprecated (PEP 632) and no longer available in Python 3.12+.
    """
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    elif val in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    else:
        raise ValueError("invalid truth value %r" % (val,))
def find(iterable, pred=lambda x: True, default=None):
    """
    Returns the first element of 'iterable' for which 'pred' is true,
    or 'default' if there is no such element.
    """
    for item in iterable:
        if pred(item):
            return item
    return default
def cast(value, type_conv, default=None):
    """
    Converts 'value' with 'type_conv', falling back to 'default' when
    'value' is None or the conversion raises any Exception.
    """
    if value is not None:
        try:
            return type_conv(value)
        except Exception:
            pass  # best-effort conversion: fall through to the default
    return default
def to_snake_case(s):
    """
    Converts a CamelCase string to snake_case.

    An underscore is inserted before an uppercase letter only when the
    preceding character is not itself an uppercase letter, so runs of
    capitals stay together: "HTML" -> "html", "CamelCase" -> "camel_case".
    """
    if not s:
        return ''

    parts = [s[0].lower()]
    for prev_char, char in zip(s, s[1:]):
        if not (char.isalpha() and char.isupper()):
            parts.append(char)
            continue

        # avoid "HTML" -> "h_t_m_l"
        if not (prev_char.isalpha() and prev_char.isupper()):
            parts.append('_')
        parts.append(char.lower())
    return ''.join(parts)
def pairs(iterable):
    """
    Groups consecutive elements into non-overlapping 2-tuples:
    [a, b, c, d, e] -> (a, b), (c, d). A trailing odd element is dropped.
    """
    # Both zip arguments are the same iterator, so each tuple consumes
    # two successive elements.
    shared = iter(iterable)
    return zip(*[shared] * 2)
def take_by(iterable, count):
    """
    Returns elements from the input iterable by batches of N items.
    ('abcdefg', 3) -> ['a', 'b', 'c'], ['d', 'e', 'f'], ['g']
    """
    source = iter(iterable)
    chunk = list(islice(source, count))
    while chunk:
        yield chunk
        chunk = list(islice(source, count))
def filter_dict(d, exclude_keys):
    """
    Returns a new dict with the items of 'd' whose keys are not
    in 'exclude_keys'. The input dict is left unchanged.
    """
    kept = {}
    for key, value in d.items():
        if key in exclude_keys:
            continue
        kept[key] = value
    return kept
def parse_str_enum_value(value, enum_class, default=NOTSET,
        unknown_member_error=None):
    """
    Converts 'value' to a member of 'enum_class'.

    - None is replaced by 'default' when a default is given
      (the default is returned as-is, without conversion)
    - a string is looked up as a member name of 'enum_class'
    - a member of 'enum_class' is returned unchanged

    'unknown_member_error' is an optional format string for the error
    message; it may use the {cls}, {value} and {available} fields.

    Raises ValueError if a string does not name a member of 'enum_class',
    and TypeError for values of any other type (including None when no
    default is given).
    """
    if value is None and default is not NOTSET:
        value = default
    elif isinstance(value, str):
        try:
            value = enum_class[value]
        except KeyError:
            message = (unknown_member_error or
                "Unknown element of {cls} '{value}'. "
                "The only known are: {available}")
            # The KeyError adds nothing to this message, so don't chain it
            raise ValueError(message.format(
                cls=enum_class.__name__,
                value=value,
                available=', '.join(e.name for e in enum_class)
            )) from None
    elif isinstance(value, enum_class):
        pass
    else:
        raise TypeError("Expected value type string or %s, but got %s" %
            (enum_class.__name__, type(value).__name__))
    return value
def escape(s: str, escapes: Iterable[Tuple[str, str]]) -> str:
    """
    Applies each (pattern, substitute) pair of 'escapes' to 's', in order,
    replacing every occurrence of 'pattern' with 'substitute'.
    """
    escaped = s
    for rule in escapes:
        pattern, substitute = rule
        escaped = escaped.replace(pattern, substitute)
    return escaped
def unescape(s: str, escapes: Iterable[Tuple[str, str]]) -> str:
    """
    Performs the reverse substitutions of escape(): for each
    (pattern, substitute) pair of 'escapes', in the given order, replaces
    every occurrence of 'substitute' in 's' with 'pattern'.
    """
    unescaped = s
    for original, escaped_form in escapes:
        unescaped = unescaped.replace(escaped_form, original)
    return unescaped
def is_method_redefined(method_name, base_class, target) -> bool:
    """
    Checks whether 'target' (a class or an instance) provides its own
    implementation of 'method_name' instead of the one from 'base_class'.

    Both sides are reduced to their underlying function before comparing,
    so bound methods (instance access), classmethods (bound to different
    classes) and staticmethods (plain functions, no __func__) all compare
    by implementation. A missing attribute on 'target' counts as redefined.
    Raises AttributeError if 'base_class' has no 'method_name'.
    """
    def _underlying(member):
        # Bound methods expose the wrapped function as __func__;
        # plain functions (and None) are returned unchanged.
        return getattr(member, '__func__', member)

    target_method = _underlying(getattr(target, method_name, None))
    base_method = _underlying(getattr(base_class, method_name))
    return base_method != target_method
def optional_arg_decorator(fn):
    """
    Lets decorator 'fn(decoratee, *args, **kwargs)' be used both bare
    ('@deco') and with arguments ('@deco(...)').

    A call with exactly one callable positional argument and no keyword
    arguments is treated as bare use; any other call returns a decorator
    that forwards the given arguments to 'fn'.
    """
    @wraps(fn)
    def wrapped_decorator(*args, **kwargs):
        used_bare = len(args) == 1 and not kwargs and callable(args[0])
        if used_bare:
            return fn(args[0])

        def real_decorator(decoratee):
            return fn(decoratee, *args, **kwargs)
        return real_decorator
    return wrapped_decorator