From: dirkf Date: Wed, 3 May 2023 11:40:09 +0000 (+0100) Subject: [utils] Align traverse_obj() with yt-dlp X-Git-Url: http://git.oshgnacknak.de/?a=commitdiff_plain;h=825a40744bf9aeb743452db24e43d3eb61feb6c2;p=youtube-dl [utils] Align traverse_obj() with yt-dlp Thanks Grub4k for these: * traverse `Iterable`s, from https://github.com/yt-dlp/yt-dlp/pull/6902, etc * traverse `set` key for transformations/filters, `re.Match` group names, from https://github.com/yt-dlp/yt-dlp/commit/776995bc109c5cd1aa56b684fada2ce718a386ec, etc * traverse `re.Match`es, from https://github.com/yt-dlp/yt-dlp/pull/5174 * always return list when branching, from https://github.com/yt-dlp/yt-dlp/pull/5170 --- diff --git a/test/test_utils.py b/test/test_utils.py index 2ee727caf..1b5d170fe 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,7 @@ import xml.etree.ElementTree from youtube_dl.utils import ( age_restricted, args_to_str, - encode_base_n, + base_url, caesar, clean_html, clean_podcast_url, @@ -29,10 +29,12 @@ from youtube_dl.utils import ( detect_exe_version, determine_ext, dict_get, + encode_base_n, encode_compat_str, encodeFilename, escape_rfc3986, escape_url, + expand_path, extract_attributes, ExtractorError, find_xpath_attr, @@ -51,6 +53,7 @@ from youtube_dl.utils import ( js_to_json, LazyList, limit_length, + lowercase_escape, merge_dicts, mimetype2ext, month_by_name, @@ -66,17 +69,16 @@ from youtube_dl.utils import ( parse_resolution, parse_bitrate, pkcs1pad, - read_batch_urls, - sanitize_filename, - sanitize_path, - sanitize_url, - expand_path, prepend_extension, - replace_extension, + read_batch_urls, remove_start, remove_end, remove_quotes, + replace_extension, rot47, + sanitize_filename, + sanitize_path, + sanitize_url, shell_quote, smuggle_url, str_or_none, @@ -93,10 +95,8 @@ from youtube_dl.utils import ( unified_timestamp, unsmuggle_url, uppercase_escape, - lowercase_escape, url_basename, url_or_none, - base_url, urljoin, urlencode_postdata, urshift, @@ -1586,6 +1586,11 @@ Line 1 'dict': {}, } + # define a pukka Iterable + def iter_range(stop): + for from_ in range(stop): + yield from_ + # Test base functionality self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', msg='allow tuple path') @@ -1602,13 +1607,13 @@ Line 1 # Test Ellipsis behavior self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), (item for item in _TEST_DATA.values() if item not in (None, {})), - msg='`...` should give all non discarded values') + msg='`...` should give all non-discarded values') self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), msg='`...` selection for dicts should select all values') self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), ['https://www.example.com/0', 'https://www.example.com/1'], msg='nested `...` queries should work') - self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), + self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), iter_range(4), msg='`...` query result should be flattened') self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)), msg='`...` should accept iterables') @@ -1618,7 +1623,7 @@ Line 1 [_TEST_DATA['urls']], msg='function as query key should perform a filter based on (key, value)') self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)), - msg='exceptions in the query function should be catched') + msg='exceptions in the query function should be caught') self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2], msg='function key should accept iterables') if __debug__: @@ -1706,7 +1711,7 @@ Line 1 self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {}, msg='remove empty values when dict key') self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis}, - msg='use `default` when dict key and `default`') + msg='use `default` when dict key and a default') self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {}, msg='remove empty values when nested dict key fails') self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, @@ -1768,7 +1773,7 @@ Line 1 self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str', msg='accept matching `expected_type` type') self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), - None, msg='reject non matching `expected_type` type') + None, msg='reject non-matching `expected_type` type') self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0', msg='transform type using type function') self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), @@ -1780,7 +1785,7 @@ Line 1 self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values') self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int), - 1, msg='expected_type should not filter non final dict values') + 1, msg='expected_type should not filter non-final dict values') self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}}, msg='expected_type should transform deep dict values') self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)), @@ -1838,7 +1843,7 @@ Line 1 self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), _traverse_string=True), 'sr', msg='`slice` should result in string if `traverse_string`') - self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"), + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'), _traverse_string=True), 'str', msg='function should result in string if `traverse_string`') self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 494f8341b..b77a7fb0e 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -4268,13 +4268,8 @@ def variadic(x, allowed_types=NO_DEFAULT): def dict_get(d, key_or_keys, default=None, skip_false_values=True): - if isinstance(key_or_keys, (list, tuple)): - for key in key_or_keys: - if key not in d or d[key] is None or skip_false_values and not d[key]: - continue - return d[key] - return default - return d.get(key_or_keys, default) + exp = (lambda x: x or None) if skip_false_values else IDENTITY + return traverse_obj(d, *variadic(key_or_keys), expected_type=exp, default=default) def try_call(*funcs, **kwargs):