[utils] Align traverse_obj() with yt-dlp
authordirkf <fieldhouse@gmx.net>
Wed, 3 May 2023 11:40:09 +0000 (12:40 +0100)
committerdirkf <fieldhouse@gmx.net>
Wed, 19 Jul 2023 21:14:50 +0000 (22:14 +0100)
Thanks Grub4k for these:
* traverse `Iterable`s, from https://github.com/yt-dlp/yt-dlp/pull/6902, etc
* traverse `set` key for transformations/filters, `re.Match` group names, from
  https://github.com/yt-dlp/yt-dlp/commit/776995bc109c5cd1aa56b684fada2ce718a386ec, etc
* traverse `re.Match`es, from https://github.com/yt-dlp/yt-dlp/pull/5174
* always return list when branching, from https://github.com/yt-dlp/yt-dlp/pull/5170

test/test_utils.py
youtube_dl/utils.py

index 2ee727caf2b5a1717f349dbb6a4997c314b3bf0a..1b5d170fe5287a62864c4e931556316371955a3e 100644 (file)
@@ -20,7 +20,7 @@ import xml.etree.ElementTree
 from youtube_dl.utils import (
     age_restricted,
     args_to_str,
-    encode_base_n,
+    base_url,
     caesar,
     clean_html,
     clean_podcast_url,
@@ -29,10 +29,12 @@ from youtube_dl.utils import (
     detect_exe_version,
     determine_ext,
     dict_get,
+    encode_base_n,
     encode_compat_str,
     encodeFilename,
     escape_rfc3986,
     escape_url,
+    expand_path,
     extract_attributes,
     ExtractorError,
     find_xpath_attr,
@@ -51,6 +53,7 @@ from youtube_dl.utils import (
     js_to_json,
     LazyList,
     limit_length,
+    lowercase_escape,
     merge_dicts,
     mimetype2ext,
     month_by_name,
@@ -66,17 +69,16 @@ from youtube_dl.utils import (
     parse_resolution,
     parse_bitrate,
     pkcs1pad,
-    read_batch_urls,
-    sanitize_filename,
-    sanitize_path,
-    sanitize_url,
-    expand_path,
     prepend_extension,
-    replace_extension,
+    read_batch_urls,
     remove_start,
     remove_end,
     remove_quotes,
+    replace_extension,
     rot47,
+    sanitize_filename,
+    sanitize_path,
+    sanitize_url,
     shell_quote,
     smuggle_url,
     str_or_none,
@@ -93,10 +95,8 @@ from youtube_dl.utils import (
     unified_timestamp,
     unsmuggle_url,
     uppercase_escape,
-    lowercase_escape,
     url_basename,
     url_or_none,
-    base_url,
     urljoin,
     urlencode_postdata,
     urshift,
@@ -1586,6 +1586,11 @@ Line 1
             'dict': {},
         }
 
+        # define a pukka Iterable
+        def iter_range(stop):
+            for from_ in range(stop):
+                yield from_
+
         # Test base functionality
         self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
                          msg='allow tuple path')
@@ -1602,13 +1607,13 @@ Line 1
         # Test Ellipsis behavior
         self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
                               (item for item in _TEST_DATA.values() if item not in (None, {})),
-                              msg='`...` should give all non discarded values')
+                              msg='`...` should give all non-discarded values')
         self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
                               msg='`...` selection for dicts should select all values')
         self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
                          ['https://www.example.com/0', 'https://www.example.com/1'],
                          msg='nested `...` queries should work')
-        self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
+        self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), iter_range(4),
                               msg='`...` query result should be flattened')
         self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
                          msg='`...` should accept iterables')
@@ -1618,7 +1623,7 @@ Line 1
                          [_TEST_DATA['urls']],
                          msg='function as query key should perform a filter based on (key, value)')
         self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)),
-                              msg='exceptions in the query function should be catched')
+                              msg='exceptions in the query function should be caught')
         self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
                          msg='function key should accept iterables')
         if __debug__:
@@ -1706,7 +1711,7 @@ Line 1
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
                          msg='remove empty values when dict key')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
-                         msg='use `default` when dict key and `default`')
+                         msg='use `default` when dict key and a default')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
                          msg='remove empty values when nested dict key fails')
         self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
@@ -1768,7 +1773,7 @@ Line 1
         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
                          'str', msg='accept matching `expected_type` type')
         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
-                         None, msg='reject non matching `expected_type` type')
+                         None, msg='reject non-matching `expected_type` type')
         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
                          '0', msg='transform type using type function')
         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
@@ -1780,7 +1785,7 @@ Line 1
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
                          {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
         self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int),
-                         1, msg='expected_type should not filter non final dict values')
+                         1, msg='expected_type should not filter non-final dict values')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
                          {0: {0: 100}}, msg='expected_type should transform deep dict values')
         self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
@@ -1838,7 +1843,7 @@ Line 1
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
                                       _traverse_string=True), 'sr',
                          msg='`slice` should result in string if `traverse_string`')
-        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
+        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'),
                                       _traverse_string=True), 'str',
                          msg='function should result in string if `traverse_string`')
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
index 494f8341b48375128f16dd0e61cc54757fac6a33..b77a7fb0e0f49851f4757452fb2e367a2941c750 100644 (file)
@@ -4268,13 +4268,8 @@ def variadic(x, allowed_types=NO_DEFAULT):
 
 
 def dict_get(d, key_or_keys, default=None, skip_false_values=True):
-    if isinstance(key_or_keys, (list, tuple)):
-        for key in key_or_keys:
-            if key not in d or d[key] is None or skip_false_values and not d[key]:
-                continue
-            return d[key]
-        return default
-    return d.get(key_or_keys, default)
+    exp = (lambda x: x or None) if skip_false_values else IDENTITY
+    return traverse_obj(d, *variadic(key_or_keys), expected_type=exp, default=default)
 
 
 def try_call(*funcs, **kwargs):