[utils] Revise `isinstance()` tests (especially for str/unicode/bytes) to complete...
authordirkf <fieldhouse@gmx.net>
Sun, 30 Jul 2023 20:47:48 +0000 (21:47 +0100)
committerdirkf <fieldhouse@gmx.net>
Tue, 1 Aug 2023 00:05:09 +0000 (01:05 +0100)
youtube_dl/compat.py
youtube_dl/utils.py

index 54ad64674fa0b4b86b5f067940ed71c802665c07..3c526a78dc521ea0e97bbd2eb02103026020a8b6 100644 (file)
@@ -36,7 +36,7 @@ try:
     )
 except NameError:
     compat_str, compat_basestring, compat_chr = (
-        str, str, chr
+        str, (str, bytes), chr
     )
 
 # casefold
index 1da5a7a38b9ca7bf7a75aee2e8f53b8339acf42c..94b339b1df89290a9a7acc7afed95e77ddd5260a 100644 (file)
@@ -1826,11 +1826,11 @@ def write_json_file(obj, fn):
     if sys.version_info < (3, 0) and sys.platform != 'win32':
         encoding = get_filesystem_encoding()
         # os.path.basename returns a bytes object, but NamedTemporaryFile
-        # will fail if the filename contains non ascii characters unless we
+        # will fail if the filename contains non-ascii characters unless we
         # use a unicode object
-        path_basename = lambda f: os.path.basename(fn).decode(encoding)
+        path_basename = lambda f: os.path.basename(f).decode(encoding)
         # the same for os.path.dirname
-        path_dirname = lambda f: os.path.dirname(fn).decode(encoding)
+        path_dirname = lambda f: os.path.dirname(f).decode(encoding)
     else:
         path_basename = os.path.basename
         path_dirname = os.path.dirname
@@ -1894,10 +1894,10 @@ else:
                 return f
         return None
 
+
 # On python2.6 the xml.etree.ElementTree.Element methods don't support
 # the namespace parameter
 
-
 def xpath_with_ns(path, ns_map):
     components = [c.split(':') for c in path.split('/')]
     replaced = []
@@ -1914,7 +1914,7 @@ def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
     def _find_xpath(xpath):
         return node.find(compat_xpath(xpath))
 
-    if isinstance(xpath, (str, compat_str)):
+    if isinstance(xpath, compat_basestring):
         n = _find_xpath(xpath)
     else:
         for xp in xpath:
@@ -2262,39 +2262,32 @@ def get_subprocess_encoding():
     return encoding
 
 
-def encodeFilename(s, for_subprocess=False):
-    """
-    @param s The name of the file
-    """
-
-    assert type(s) == compat_str
-
-    # Python 3 has a Unicode API
-    if sys.version_info >= (3, 0):
-        return s
-
-    # Pass '' directly to use Unicode APIs on Windows 2000 and up
-    # (Detecting Windows NT 4 is tricky because 'major >= 4' would
-    # match Windows 9x series as well. Besides, NT 4 is obsolete.)
-    if not for_subprocess and sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
-        return s
-
-    # Jython assumes filenames are Unicode strings though reported as Python 2.x compatible
-    if sys.platform.startswith('java'):
-        return s
+# Jython assumes filenames are Unicode strings though reported as Python 2.x compatible
+if sys.version_info < (3, 0) and not sys.platform.startswith('java'):
 
-    return s.encode(get_subprocess_encoding(), 'ignore')
+    def encodeFilename(s, for_subprocess=False):
+        """
+        @param s The name of the file
+        """
 
+        # Pass '' directly to use Unicode APIs on Windows 2000 and up
+        # (Detecting Windows NT 4 is tricky because 'major >= 4' would
+        # match Windows 9x series as well. Besides, NT 4 is obsolete.)
+        if (not for_subprocess
+                and sys.platform == 'win32'
+                and sys.getwindowsversion()[0] >= 5
+                and isinstance(s, compat_str)):
+            return s
 
-def decodeFilename(b, for_subprocess=False):
+        return _encode_compat_str(s, get_subprocess_encoding(), 'ignore')
 
-    if sys.version_info >= (3, 0):
-        return b
+    def decodeFilename(b, for_subprocess=False):
+        return _decode_compat_str(b, get_subprocess_encoding(), 'ignore')
 
-    if not isinstance(b, bytes):
-        return b
+else:
 
-    return b.decode(get_subprocess_encoding(), 'ignore')
+    # Python 3 has a Unicode API
+    encodeFilename = decodeFilename = lambda *s, **k: s[0]
 
 
 def encodeArgument(s):
@@ -2313,11 +2306,7 @@ def decodeArgument(b):
 def decodeOption(optval):
     if optval is None:
         return optval
-    if isinstance(optval, bytes):
-        optval = optval.decode(preferredencoding())
-
-    assert isinstance(optval, compat_str)
-    return optval
+    return _decode_compat_str(optval)
 
 
 def formatSeconds(secs):
@@ -2363,7 +2352,7 @@ def make_HTTPS_handler(params, **kwargs):
 
     if sys.version_info < (3, 2):
         return YoutubeDLHTTPSHandler(params, **kwargs)
-    else:  # Python < 3.4
+    else:  # Python3 < 3.4
         context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
         context.verify_mode = (ssl.CERT_NONE
                                if opts_no_check_certificate
@@ -2818,8 +2807,7 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
                 location_escaped = escape_url(location_fixed)
                 if location != location_escaped:
                     del resp.headers['Location']
-                    # if sys.version_info < (3, 0):
-                    if not isinstance(location_escaped, str):
+                    if not isinstance(location_escaped, str):  # Py 2 case
                         location_escaped = location_escaped.encode('utf-8')
                     resp.headers['Location'] = location_escaped
         return resp
@@ -3086,8 +3074,7 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
         # On python 2 urlh.geturl() may sometimes return redirect URL
         # as a byte string instead of unicode. This workaround forces
         # it to return unicode.
-        if sys.version_info[0] < 3:
-            newurl = compat_str(newurl)
+        newurl = _decode_compat_str(newurl)
 
         # Be conciliant with URIs containing a space.  This is mainly
         # redundant with the more complete encoding done in http_error_302(),
@@ -3333,11 +3320,7 @@ class DateRange(object):
 def platform_name():
     """ Returns the platform name as a compat_str """
     res = platform.platform()
-    if isinstance(res, bytes):
-        res = res.decode(preferredencoding())
-
-    assert isinstance(res, compat_str)
-    return res
+    return _decode_compat_str(res)
 
 
 def _windows_write_string(s, out):
@@ -3567,9 +3550,8 @@ def shell_quote(args):
     quoted_args = []
     encoding = get_filesystem_encoding()
     for a in args:
-        if isinstance(a, bytes):
-            # We may get a filename encoded with 'encodeFilename'
-            a = a.decode(encoding)
+        # We may get a filename encoded with 'encodeFilename'
+        a = _decode_compat_str(a, encoding)
         quoted_args.append(compat_shlex_quote(a))
     return ' '.join(quoted_args)
 
@@ -3733,8 +3715,9 @@ def parse_resolution(s):
 
 
 def parse_bitrate(s):
-    if not isinstance(s, compat_str):
-        return
+    s = txt_or_none(s)
+    if not s:
+        return None
     mobj = re.search(r'\b(\d+)\s*kbps', s)
     if mobj:
         return int(mobj.group(1))
@@ -3822,18 +3805,17 @@ def base_url(url):
 
 
 def urljoin(base, path):
-    if isinstance(path, bytes):
-        path = path.decode('utf-8')
-    if not isinstance(path, compat_str) or not path:
+    path = _decode_compat_str(path, encoding='utf-8', or_none=True)
+    if not path:
         return None
     if re.match(r'^(?:[a-zA-Z][a-zA-Z0-9+-.]*:)?//', path):
         return path
-    if isinstance(base, bytes):
-        base = base.decode('utf-8')
-    if not isinstance(base, compat_str) or not re.match(
-            r'^(?:https?:)?//', base):
+    base = _decode_compat_str(base, encoding='utf-8', or_none=True)
+    if not base:
         return None
-    return compat_urllib_parse.urljoin(base, path)
+    return (
+        re.match(r'^(?:https?:)?//', base)
+        and compat_urllib_parse.urljoin(base, path))
 
 
 class HEADRequest(compat_urllib_request.Request):
@@ -3998,8 +3980,7 @@ def get_exe_version(exe, args=['--version'],
             stdout=subprocess.PIPE, stderr=subprocess.STDOUT))
     except OSError:
         return False
-    if isinstance(out, bytes):  # Python 2.x
-        out = out.decode('ascii', 'ignore')
+    out = _decode_compat_str(out, 'ascii', 'ignore')
     return detect_exe_version(out, version_re, unrecognized)
 
 
@@ -4218,8 +4199,8 @@ def lowercase_escape(s):
 
 def escape_rfc3986(s):
     """Escape non-ASCII characters as suggested by RFC 3986"""
-    if sys.version_info < (3, 0) and isinstance(s, compat_str):
-        s = s.encode('utf-8')
+    if sys.version_info < (3, 0):
+        s = _encode_compat_str(s, 'utf-8')
     # ensure unicode: after quoting, it can always be converted
     return compat_str(compat_urllib_parse.quote(s, b"%/;:@&=+$,!~*'()?#[]"))
 
@@ -4242,8 +4223,7 @@ def parse_qs(url, **kwargs):
 
 def read_batch_urls(batch_fd):
     def fixup(url):
-        if not isinstance(url, compat_str):
-            url = url.decode('utf-8', 'replace')
+        url = _decode_compat_str(url, 'utf-8', 'replace')
         BOM_UTF8 = '\xef\xbb\xbf'
         if url.startswith(BOM_UTF8):
             url = url[len(BOM_UTF8):]
@@ -4305,10 +4285,8 @@ def _multipart_encode_impl(data, boundary):
     out = b''
     for k, v in data.items():
         out += b'--' + boundary.encode('ascii') + b'\r\n'
-        if isinstance(k, compat_str):
-            k = k.encode('utf-8')
-        if isinstance(v, compat_str):
-            v = v.encode('utf-8')
+        k = _encode_compat_str(k, 'utf-8')
+        v = _encode_compat_str(v, 'utf-8')
         # RFC 2047 requires non-ASCII field names to be encoded, while RFC 7578
         # suggests sending UTF-8 directly. Firefox sends UTF-8, too
         content = b'Content-Disposition: form-data; name="' + k + b'"\r\n\r\n' + v + b'\r\n'
@@ -4435,8 +4413,26 @@ def merge_dicts(*dicts, **kwargs):
     return merged
 
 
-def encode_compat_str(string, encoding=preferredencoding(), errors='strict'):
-    return string if isinstance(string, compat_str) else compat_str(string, encoding, errors)
+# very poor choice of name, as if Python string encodings weren't confusing enough
+def encode_compat_str(s, encoding=preferredencoding(), errors='strict'):
+    assert isinstance(s, compat_basestring)
+    return s if isinstance(s, compat_str) else compat_str(s, encoding, errors)
+
+
+# what it could have been
+def _decode_compat_str(s, encoding=preferredencoding(), errors='strict', or_none=False):
+    if not or_none:
+        assert isinstance(s, compat_basestring)
+    return (
+        s if isinstance(s, compat_str)
+        else compat_str(s, encoding, errors) if isinstance(s, compat_basestring)
+        else None)
+
+
+# the real encode_compat_str, but only for internal use
+def _encode_compat_str(s, encoding=preferredencoding(), errors='strict'):
+    assert isinstance(s, compat_basestring)
+    return s.encode(encoding, errors) if isinstance(s, compat_str) else s
 
 
 US_RATINGS = {
@@ -4639,12 +4635,7 @@ def args_to_str(args):
 
 
 def error_to_compat_str(err):
-    err_str = str(err)
-    # On python 2 error byte string must be decoded with proper
-    # encoding rather than ascii
-    if sys.version_info[0] < 3:
-        err_str = err_str.decode(preferredencoding())
-    return err_str
+    return _decode_compat_str(str(err))
 
 
 def mimetype2ext(mt):