compat_chr,
compat_etree_fromstring,
compat_getenv,
+ compat_http_cookies,
compat_os_name,
compat_setenv,
compat_str,
class TestUtil(unittest.TestCase):
- # yt-dlp shim
- def assertCountEqual(self, expected, got, msg='count should be the same'):
- return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg)
-
def test_timeconvert(self):
self.assertTrue(timeconvert('') is None)
self.assertTrue(timeconvert('bougrg') is None)
self.assertRaises(
ValueError, multipart_encode, {b'field': b'value'}, boundary='value')
- def test_dict_get(self):
- FALSE_VALUES = {
- 'none': None,
- 'false': False,
- 'zero': 0,
- 'empty_string': '',
- 'empty_list': [],
- }
- d = FALSE_VALUES.copy()
- d['a'] = 42
- self.assertEqual(dict_get(d, 'a'), 42)
- self.assertEqual(dict_get(d, 'b'), None)
- self.assertEqual(dict_get(d, 'b', 42), 42)
- self.assertEqual(dict_get(d, ('a', )), 42)
- self.assertEqual(dict_get(d, ('b', 'a', )), 42)
- self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42)
- self.assertEqual(dict_get(d, ('b', 'c', )), None)
- self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42)
- for key, false_value in FALSE_VALUES.items():
- self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
- self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
-
def test_merge_dicts(self):
self.assertEqual(merge_dicts({'a': 1}, {'b': 2}), {'a': 1, 'b': 2})
self.assertEqual(merge_dicts({'a': 1}, {'a': 2}), {'a': 1})
self.assertEqual(variadic('spam', allowed_types=dict), 'spam')
self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
+ def test_join_nonempty(self):
+ self.assertEqual(join_nonempty('a', 'b'), 'a-b')
+ self.assertEqual(join_nonempty(
+ 'a', 'b', 'c', 'd',
+ from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
+
+
+class TestTraversal(unittest.TestCase):
+ str = compat_str
+ _TEST_DATA = {
+ 100: 100,
+ 1.2: 1.2,
+ 'str': 'str',
+ 'None': None,
+ '...': Ellipsis,
+ 'urls': [
+ {'index': 0, 'url': 'https://www.example.com/0'},
+ {'index': 1, 'url': 'https://www.example.com/1'},
+ ],
+ 'data': (
+ {'index': 2},
+ {'index': 3},
+ ),
+ 'dict': {},
+ }
+
+ # yt-dlp shim
+ def assertCountEqual(self, expected, got, msg='count should be the same'):
+ return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg)
+
+ def assertMaybeCountEqual(self, *args, **kwargs):
+ if sys.version_info < (3, 7):
+ # random dict order
+ return self.assertCountEqual(*args, **kwargs)
+ else:
+ return self.assertEqual(*args, **kwargs)
+
def test_traverse_obj(self):
- str = compat_str
- _TEST_DATA = {
- 100: 100,
- 1.2: 1.2,
- 'str': 'str',
- 'None': None,
- '...': Ellipsis,
- 'urls': [
- {'index': 0, 'url': 'https://www.example.com/0'},
- {'index': 1, 'url': 'https://www.example.com/1'},
- ],
- 'data': (
- {'index': 2},
- {'index': 3},
- ),
- 'dict': {},
- }
+ str = self.str
+ _TEST_DATA = self._TEST_DATA
# define a pukka Iterable
def iter_range(stop):
# Test set as key (transformation/type, like `expected_type`)
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'],
msg='Function in set should be a transformation')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('fail', T(lambda _: 'const'))), 'const',
+ msg='Function in set should always be called')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'],
msg='Type in set should be a type filter')
+ self.assertMaybeCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str, int))), [100, 'str'],
+ msg='Multiple types in set should be a type filter')
self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA,
msg='A single set should be wrapped into a path')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'],
msg='Transformation function should not raise')
- self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))),
- [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
- msg='Function in set should be a transformation')
+ self.assertMaybeCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))),
+ [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
+ msg='Function in set should be a transformation')
if __debug__:
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set())
self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
- # Test is_user_input behavior
- _IS_USER_INPUT_DATA = {'range8': list(range(8))}
- self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
- _is_user_input=True), 3,
- msg='allow for string indexing if `is_user_input`')
- self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
- _is_user_input=True), tuple(range(8))[3:],
- msg='allow for string slice if `is_user_input`')
- self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
- _is_user_input=True), tuple(range(8))[:4:2],
- msg='allow step in string slice if `is_user_input`')
- self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
- _is_user_input=True), range(8),
- msg='`:` should be treated as `...` if `is_user_input`')
- with self.assertRaises(TypeError, msg='too many params should result in error'):
- traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), _is_user_input=True)
-
# Test re.Match as input obj
mobj = re.match(r'^0(12)(?P<group>3)(4)?$', '0123')
self.assertEqual(traverse_obj(mobj, Ellipsis), [x for x in mobj.groups() if x is not None],
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
+ # Test xml.etree.ElementTree.Element as input obj
+ etree = compat_etree_fromstring('''<?xml version="1.0"?>
+ <data>
+ <country name="Liechtenstein">
+ <rank>1</rank>
+ <year>2008</year>
+ <gdppc>141100</gdppc>
+ <neighbor name="Austria" direction="E"/>
+ <neighbor name="Switzerland" direction="W"/>
+ </country>
+ <country name="Singapore">
+ <rank>4</rank>
+ <year>2011</year>
+ <gdppc>59900</gdppc>
+ <neighbor name="Malaysia" direction="N"/>
+ </country>
+ <country name="Panama">
+ <rank>68</rank>
+ <year>2011</year>
+ <gdppc>13600</gdppc>
+ <neighbor name="Costa Rica" direction="W"/>
+ <neighbor name="Colombia" direction="E"/>
+ </country>
+ </data>''')
+ self.assertEqual(traverse_obj(etree, ''), etree,
+ msg='empty str key should return the element itself')
+ self.assertEqual(traverse_obj(etree, 'country'), list(etree),
+ msg='str key should return all children with that tag name')
+ self.assertEqual(traverse_obj(etree, Ellipsis), list(etree),
+ msg='`...` as key should return all children')
+ self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
+ msg='function as key should get element as value')
+ self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
+ msg='function as key should get index as key')
+ self.assertEqual(traverse_obj(etree, 0), etree[0],
+ msg='int key should return the nth child')
+ self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
+ ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
+ msg='`@<attribute>` at end of path should give that attribute')
+ self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
+ msg='`@<nonexistent>` at end of path should give `None`')
+ self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
+ msg='`@` should give the full attribute dict')
+ self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
+ msg='`text()` at end of path should give the inner text')
+ self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
+ msg='full python xpath features should be supported')
+ self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
+ msg='special transformations should act on current element')
+ self.assertEqual(traverse_obj(etree, ('country', 0, Ellipsis, 'text()', T(int_or_none))), [1, 2008, 141100],
+ msg='special transformations should act on current element')
+
+ def test_traversal_unbranching(self):
+ # str = self.str
+ _TEST_DATA = self._TEST_DATA
+
+ self.assertEqual(traverse_obj(_TEST_DATA, [(100, 1.2), all]), [100, 1.2],
+ msg='`all` should give all results as list')
+ self.assertEqual(traverse_obj(_TEST_DATA, [(100, 1.2), any]), 100,
+ msg='`any` should give the first result')
+ self.assertEqual(traverse_obj(_TEST_DATA, [100, all]), [100],
+ msg='`all` should give list if non branching')
+ self.assertEqual(traverse_obj(_TEST_DATA, [100, any]), 100,
+ msg='`any` should give single item if non branching')
+ self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]), [100],
+ msg='`all` should filter `None` and empty dict')
+ self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]), 100,
+ msg='`any` should filter `None` and empty dict')
+ self.assertEqual(traverse_obj(_TEST_DATA, [{
+ 'all': [('dict', 'None', 100, 1.2), all],
+ 'any': [('dict', 'None', 100, 1.2), any],
+ }]), {'all': [100, 1.2], 'any': 100},
+ msg='`all`/`any` should apply to each dict path separately')
+ self.assertEqual(traverse_obj(_TEST_DATA, [{
+ 'all': [('dict', 'None', 100, 1.2), all],
+ 'any': [('dict', 'None', 100, 1.2), any],
+ }], get_all=False), {'all': [100, 1.2], 'any': 100},
+ msg='`all`/`any` should apply to dict regardless of `get_all`')
+ self.assertIs(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, T(float)]), None,
+ msg='`all` should reset branching status')
+ self.assertIs(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, T(float)]), None,
+ msg='`any` should reset branching status')
+ self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, Ellipsis, T(float)]), [1.2],
+ msg='`all` should allow further branching')
+ self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, Ellipsis, 'index']), [0, 1],
+ msg='`any` should allow further branching')
+
+ def test_traversal_morsel(self):
+ values = {
+ 'expires': 'a',
+ 'path': 'b',
+ 'comment': 'c',
+ 'domain': 'd',
+ 'max-age': 'e',
+ 'secure': 'f',
+ 'httponly': 'g',
+ 'version': 'h',
+ 'samesite': 'i',
+ }
+ # SameSite added in Py3.8, breaks .update for 3.5-3.7
+ if sys.version_info < (3, 8):
+ del values['samesite']
+ morsel = compat_http_cookies.Morsel()
+ morsel.set(str('item_key'), 'item_value', 'coded_value')
+ morsel.update(values)
+ values['key'] = str('item_key')
+ values['value'] = 'item_value'
+ values = dict((str(k), v) for k, v in values.items())
+ # make test pass even without ordered dict
+ value_set = set(values.values())
+
+ for key, value in values.items():
+ self.assertEqual(traverse_obj(morsel, key), value,
+ msg='Morsel should provide access to all values')
+ self.assertEqual(set(traverse_obj(morsel, Ellipsis)), value_set,
+ msg='`...` should yield all values')
+ self.assertEqual(set(traverse_obj(morsel, lambda k, v: True)), value_set,
+ msg='function key should yield all values')
+ self.assertIs(traverse_obj(morsel, [(None,), any]), morsel,
+ msg='Morsel should not be implicitly changed to dict on usage')
+
def test_get_first(self):
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
- def test_join_nonempty(self):
- self.assertEqual(join_nonempty('a', 'b'), 'a-b')
- self.assertEqual(join_nonempty(
- 'a', 'b', 'c', 'd',
- from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
+ def test_dict_get(self):
+ FALSE_VALUES = {
+ 'none': None,
+ 'false': False,
+ 'zero': 0,
+ 'empty_string': '',
+ 'empty_list': [],
+ }
+ d = FALSE_VALUES.copy()
+ d['a'] = 42
+ self.assertEqual(dict_get(d, 'a'), 42)
+ self.assertEqual(dict_get(d, 'b'), None)
+ self.assertEqual(dict_get(d, 'b', 42), 42)
+ self.assertEqual(dict_get(d, ('a', )), 42)
+ self.assertEqual(dict_get(d, ('b', 'a', )), 42)
+ self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42)
+ self.assertEqual(dict_get(d, ('b', 'c', )), None)
+ self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42)
+ for key, false_value in FALSE_VALUES.items():
+ self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
+ self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
if __name__ == '__main__':
compat_cookiejar,
compat_ctypes_WINFUNCTYPE,
compat_datetime_timedelta_total_seconds,
+ compat_etree_Element,
compat_etree_fromstring,
+ compat_etree_iterfind,
compat_expanduser,
compat_html_entities,
compat_html_entities_html5,
compat_http_client,
+ compat_http_cookies,
compat_integer_types,
compat_kwargs,
compat_ncompress as ncompress,
def traverse_obj(obj, *paths, **kwargs):
"""
- Safely traverse nested `dict`s and `Iterable`s
+ Safely traverse nested `dict`s and `Iterable`s, etc
>>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key"))
- "value"
+ 'value'
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
- Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
+ Supported values for traversal are `Mapping`, `Iterable`, `re.Match`, `xml.etree.ElementTree`
+ (xpath) and `http.cookies.Morsel`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
The keys in the path can be one of:
- `None`: Return the current object.
- `set`: Requires the only item in the set to be a type or function,
- like `{type}`/`{func}`. If a `type`, returns only values
- of this type. If a function, returns `func(obj)`.
+ like `{type}`/`{type, type, ...}`/`{func}`. If one or more `type`s,
+ return only values that have one of the types. If a function,
+ return `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values.
For `Iterable`s, `key` is the enumeration count of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- - `dict` Transform the current object and return a matching dict.
+ - `dict`: Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
+ - `any`-builtin: Take the first matching object and return it, resetting branching.
+ - `all`-builtin: Take all matching objects and return them as a list, resetting branching.
`tuple`, `list`, and `dict` all support nested paths and branches.
@param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive.
- The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
+ The following is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API
- @param _is_user_input Whether the keys are generated from user input.
- If `True` strings get converted to `int`/`slice` if needed.
@param _traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be
converted into a string and then traversed into.
expected_type = kwargs.get('expected_type')
get_all = kwargs.get('get_all', True)
casesense = kwargs.get('casesense', True)
- _is_user_input = kwargs.get('_is_user_input', False)
_traverse_string = kwargs.get('_traverse_string', False)
# instant compat
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def lookup_or_none(v, k, getter=None):
- try:
+ with compat_contextlib_suppress(LookupError):
return getter(v, k) if getter else v[k]
- except IndexError:
- return None
def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
result = obj
elif isinstance(key, set):
- assert len(key) == 1, 'Set should only be used to wrap a single item'
- item = next(iter(key))
- if isinstance(item, type):
- result = obj if isinstance(obj, item) else None
+ assert len(key) >= 1, 'At least one item is required in a `set` key'
+ if all(isinstance(item, type) for item in key):
+ result = obj if isinstance(obj, tuple(key)) else None
else:
- result = try_call(item, args=(obj,))
+ item = next(iter(key))
+ assert len(key) == 1, 'Multiple items in a `set` key must all be types'
+ result = try_call(item, args=(obj,)) if not isinstance(item, type) else None
elif isinstance(key, (list, tuple)):
branching = True
elif key is Ellipsis:
branching = True
+ if isinstance(obj, compat_http_cookies.Morsel):
+ obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, compat_collections_abc.Mapping):
result = obj.values()
- elif is_iterable_like(obj):
+ elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)):
result = obj
elif isinstance(obj, compat_re_Match):
result = obj.groups()
elif callable(key):
branching = True
+ if isinstance(obj, compat_http_cookies.Morsel):
+ obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, compat_collections_abc.Mapping):
iter_obj = obj.items()
- elif is_iterable_like(obj):
+ elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)):
iter_obj = enumerate(obj)
elif isinstance(obj, compat_re_Match):
iter_obj = itertools.chain(
if v is not None or default is not NO_DEFAULT) or None
elif isinstance(obj, compat_collections_abc.Mapping):
+ if isinstance(obj, compat_http_cookies.Morsel):
+ obj = dict(obj, key=obj.key, value=obj.value)
result = (try_call(obj.get, args=(key,))
if casesense or try_call(obj.__contains__, args=(key,))
else next((v for k, v in obj.items() if casefold(k) == key), None))
else:
result = None
if isinstance(key, (int, slice)):
- if is_iterable_like(obj, compat_collections_abc.Sequence):
+ if is_iterable_like(obj, (compat_collections_abc.Sequence, compat_etree_Element)):
branching = isinstance(key, slice)
result = lookup_or_none(obj, key)
elif _traverse_string:
result = lookup_or_none(str(obj), key)
+ elif isinstance(obj, compat_etree_Element) and isinstance(key, str):
+ xpath, _, special = key.rpartition('/')
+ if not special.startswith('@') and not special.endswith('()'):
+ xpath = key
+ special = None
+
+ # Allow abbreviations of relative paths, absolute paths error
+ if xpath.startswith('/'):
+ xpath = '.' + xpath
+ elif xpath and not xpath.startswith('./'):
+ xpath = './' + xpath
+
+ def apply_specials(element):
+ if special is None:
+ return element
+ if special == '@':
+ return element.attrib
+ if special.startswith('@'):
+ return try_call(element.attrib.get, args=(special[1:],))
+ if special == 'text()':
+ return element.text
+ raise SyntaxError('apply_specials is missing case for {0!r}'.format(special))
+
+ if xpath:
+ result = list(map(apply_specials, compat_etree_iterfind(obj, xpath)))
+ else:
+ result = apply_specials(obj)
+
return branching, result if branching else (result,)
def lazy_last(iterable):
key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
- if _is_user_input and isinstance(key, str):
- if key == ':':
- key = Ellipsis
- elif ':' in key:
- key = slice(*map(int_or_none, key.split(':')))
- elif int_or_none(key) is not None:
- key = int(key)
-
if not casesense and isinstance(key, str):
key = compat_casefold(key)
+ if key in (any, all):
+ has_branched = False
+ filtered_objs = (obj for obj in objs if obj not in (None, {}))
+ if key is any:
+ objs = (next(filtered_objs, None),)
+ else:
+ objs = (list(filtered_objs),)
+ continue
+
if __debug__ and callable(key):
# Verify function signature
_try_bind_args(key, None, None)
return None if default is NO_DEFAULT else default
-def T(x):
- """ For use in yt-dl instead of {type} or set((type,)) """
- return set((x,))
+def T(*x):
+ """ For use in yt-dl instead of {type, ...} or set((type, ...)) """
+ return set(x)
def get_first(obj, keys, **kwargs):