[core,utils] Implement unsafe file extension mitigation
authordirkf <fieldhouse@gmx.net>
Sun, 30 Jun 2024 17:37:25 +0000 (18:37 +0100)
committerdirkf <fieldhouse@gmx.net>
Tue, 2 Jul 2024 14:38:50 +0000 (15:38 +0100)
* from https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4, thx grub4k

test/test_utils.py
youtube_dl/YoutubeDL.py
youtube_dl/utils.py

index de7fe80b8b68c214f88e0a7dfa461b1efa76fe21..2947cce7eb34c2d08a1d2e01912f534fa8f5e96c 100644 (file)
@@ -14,9 +14,11 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import io
 import itertools
 import json
+import types
 import xml.etree.ElementTree
 
 from youtube_dl.utils import (
+    _UnsafeExtensionError,
     age_restricted,
     args_to_str,
     base_url,
@@ -270,6 +272,27 @@ class TestUtil(unittest.TestCase):
             expand_path('~/%s' % env('YOUTUBE_DL_EXPATH_PATH')),
             '%s/expanded' % compat_getenv('HOME'))
 
+    _uncommon_extensions = [
+        ('exe', 'abc.exe.ext'),
+        ('de', 'abc.de.ext'),
+        ('../.mp4', None),
+        ('..\\.mp4', None),
+    ]
+
+    def assertUnsafeExtension(self, ext=None):
+        assert_raises = self.assertRaises(_UnsafeExtensionError)
+        assert_raises.ext = ext
+        orig_exit = assert_raises.__exit__
+
+        def my_exit(self_, exc_type, exc_val, exc_tb):
+            did_raise = orig_exit(exc_type, exc_val, exc_tb)
+            if did_raise and assert_raises.ext is not None:
+                self.assertEqual(assert_raises.ext, assert_raises.exception.extension, 'Unsafe extension  not as unexpected')
+            return did_raise
+
+        assert_raises.__exit__ = types.MethodType(my_exit, assert_raises)
+        return assert_raises
+
     def test_prepend_extension(self):
         self.assertEqual(prepend_extension('abc.ext', 'temp'), 'abc.temp.ext')
         self.assertEqual(prepend_extension('abc.ext', 'temp', 'ext'), 'abc.temp.ext')
@@ -278,6 +301,19 @@ class TestUtil(unittest.TestCase):
         self.assertEqual(prepend_extension('.abc', 'temp'), '.abc.temp')
         self.assertEqual(prepend_extension('.abc.ext', 'temp'), '.abc.temp.ext')
 
+        # Test uncommon extensions
+        self.assertEqual(prepend_extension('abc.ext', 'bin'), 'abc.bin.ext')
+        for ext, result in self._uncommon_extensions:
+            with self.assertUnsafeExtension(ext):
+                prepend_extension('abc', ext)
+            if result:
+                self.assertEqual(prepend_extension('abc.ext', ext, 'ext'), result)
+            else:
+                with self.assertUnsafeExtension(ext):
+                    prepend_extension('abc.ext', ext, 'ext')
+            with self.assertUnsafeExtension(ext):
+                prepend_extension('abc.unexpected_ext', ext, 'ext')
+
     def test_replace_extension(self):
         self.assertEqual(replace_extension('abc.ext', 'temp'), 'abc.temp')
         self.assertEqual(replace_extension('abc.ext', 'temp', 'ext'), 'abc.temp')
@@ -286,6 +322,16 @@ class TestUtil(unittest.TestCase):
         self.assertEqual(replace_extension('.abc', 'temp'), '.abc.temp')
         self.assertEqual(replace_extension('.abc.ext', 'temp'), '.abc.temp')
 
+        # Test uncommon extensions
+        self.assertEqual(replace_extension('abc.ext', 'bin'), 'abc.unknown_video')
+        for ext, _ in self._uncommon_extensions:
+            with self.assertUnsafeExtension(ext):
+                replace_extension('abc', ext)
+            with self.assertUnsafeExtension(ext):
+                replace_extension('abc.ext', ext, 'ext')
+            with self.assertUnsafeExtension(ext):
+                replace_extension('abc.unexpected_ext', ext, 'ext')
+
     def test_subtitles_filename(self):
         self.assertEqual(subtitles_filename('abc.ext', 'en', 'vtt'), 'abc.en.vtt')
         self.assertEqual(subtitles_filename('abc.ext', 'en', 'vtt', 'ext'), 'abc.en.vtt')
index dad44435f0acc272ecfb3149b2caa7969de32324..c19501915e4ecbe905b0b0184bfc7cce3f970c9a 100755 (executable)
@@ -7,6 +7,7 @@ import collections
 import copy
 import datetime
 import errno
+import functools
 import io
 import itertools
 import json
@@ -53,6 +54,7 @@ from .compat import (
     compat_urllib_request_DataHandler,
 )
 from .utils import (
+    _UnsafeExtensionError,
     age_restricted,
     args_to_str,
     bug_reports_message,
@@ -129,6 +131,20 @@ if compat_os_name == 'nt':
     import ctypes
 
 
+def _catch_unsafe_file_extension(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        try:
+            return func(self, *args, **kwargs)
+        except _UnsafeExtensionError as error:
+            self.report_error(
+                '{0} found; to avoid damaging your system, this value is disallowed.'
+                ' If you believe this is an error{1}').format(
+                    error.message, bug_reports_message(','))
+
+    return wrapper
+
+
 class YoutubeDL(object):
     """YoutubeDL class.
 
@@ -1925,6 +1941,7 @@ class YoutubeDL(object):
         if self.params.get('forcejson', False):
             self.to_stdout(json.dumps(self.sanitize_info(info_dict)))
 
+    @_catch_unsafe_file_extension
     def process_info(self, info_dict):
         """Process a single resolved IE result."""
 
index 1af3e2b57da2ae699d2bd7d364c726b25f91978e..df203b97ab2c7662a25664ecf4248448ea386937 100644 (file)
@@ -1717,39 +1717,6 @@ TIMEZONE_NAMES = {
     'PST': -8, 'PDT': -7   # Pacific
 }
 
-
-class Namespace(object):
-    """Immutable namespace"""
-
-    def __init__(self, **kw_attr):
-        self.__dict__.update(kw_attr)
-
-    def __iter__(self):
-        return iter(self.__dict__.values())
-
-    @property
-    def items_(self):
-        return self.__dict__.items()
-
-
-MEDIA_EXTENSIONS = Namespace(
-    common_video=('avi', 'flv', 'mkv', 'mov', 'mp4', 'webm'),
-    video=('3g2', '3gp', 'f4v', 'mk3d', 'divx', 'mpg', 'ogv', 'm4v', 'wmv'),
-    common_audio=('aiff', 'alac', 'flac', 'm4a', 'mka', 'mp3', 'ogg', 'opus', 'wav'),
-    audio=('aac', 'ape', 'asf', 'f4a', 'f4b', 'm4b', 'm4p', 'm4r', 'oga', 'ogx', 'spx', 'vorbis', 'wma', 'weba'),
-    thumbnails=('jpg', 'png', 'webp'),
-    # storyboards=('mhtml', ),
-    subtitles=('srt', 'vtt', 'ass', 'lrc', 'ttml'),
-    manifests=('f4f', 'f4m', 'm3u8', 'smil', 'mpd'),
-)
-MEDIA_EXTENSIONS.video = MEDIA_EXTENSIONS.common_video + MEDIA_EXTENSIONS.video
-MEDIA_EXTENSIONS.audio = MEDIA_EXTENSIONS.common_audio + MEDIA_EXTENSIONS.audio
-
-KNOWN_EXTENSIONS = (
-    MEDIA_EXTENSIONS.video + MEDIA_EXTENSIONS.audio
-    + MEDIA_EXTENSIONS.manifests
-)
-
 # needed for sanitizing filenames in restricted mode
 ACCENT_CHARS = dict(zip('ÂÃÄÀÁÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖŐØŒÙÚÛÜŰÝÞßàáâãäåæçèéêëìíîïðñòóôõöőøœùúûüűýþÿ',
                         itertools.chain('AAAAAA', ['AE'], 'CEEEEIIIIDNOOOOOOO', ['OE'], 'UUUUUY', ['TH', 'ss'],
@@ -3977,19 +3944,22 @@ def parse_duration(s):
     return duration
 
 
-def prepend_extension(filename, ext, expected_real_ext=None):
+def _change_extension(prepend, filename, ext, expected_real_ext=None):
     name, real_ext = os.path.splitext(filename)
-    return (
-        '{0}.{1}{2}'.format(name, ext, real_ext)
-        if not expected_real_ext or real_ext[1:] == expected_real_ext
-        else '{0}.{1}'.format(filename, ext))
+    sanitize_extension = _UnsafeExtensionError.sanitize_extension
 
+    if not expected_real_ext or real_ext.partition('.')[0::2] == ('', expected_real_ext):
+        filename = name
+        if prepend and real_ext:
+            sanitize_extension(ext, prepend=prepend)
+            return ''.join((filename, '.', ext, real_ext))
 
-def replace_extension(filename, ext, expected_real_ext=None):
-    name, real_ext = os.path.splitext(filename)
-    return '{0}.{1}'.format(
-        name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
-        ext)
+    # Mitigate path traversal and file impersonation attacks
+    return '.'.join((filename, sanitize_extension(ext)))
+
+
+prepend_extension = functools.partial(_change_extension, True)
+replace_extension = functools.partial(_change_extension, False)
 
 
 def check_executable(exe, args=[]):
@@ -6579,3 +6549,136 @@ def join_nonempty(*values, **kwargs):
     if from_dict is not None:
         values = (traverse_obj(from_dict, variadic(v)) for v in values)
     return delim.join(map(compat_str, filter(None, values)))
+
+
+class Namespace(object):
+    """Immutable namespace"""
+
+    def __init__(self, **kw_attr):
+        self.__dict__.update(kw_attr)
+
+    def __iter__(self):
+        return iter(self.__dict__.values())
+
+    @property
+    def items_(self):
+        return self.__dict__.items()
+
+
+MEDIA_EXTENSIONS = Namespace(
+    common_video=('avi', 'flv', 'mkv', 'mov', 'mp4', 'webm'),
+    video=('3g2', '3gp', 'f4v', 'mk3d', 'divx', 'mpg', 'ogv', 'm4v', 'wmv'),
+    common_audio=('aiff', 'alac', 'flac', 'm4a', 'mka', 'mp3', 'ogg', 'opus', 'wav'),
+    audio=('aac', 'ape', 'asf', 'f4a', 'f4b', 'm4b', 'm4p', 'm4r', 'oga', 'ogx', 'spx', 'vorbis', 'wma', 'weba'),
+    thumbnails=('jpg', 'png', 'webp'),
+    # storyboards=('mhtml', ),
+    subtitles=('srt', 'vtt', 'ass', 'lrc', 'ttml'),
+    manifests=('f4f', 'f4m', 'm3u8', 'smil', 'mpd'),
+)
+MEDIA_EXTENSIONS.video = MEDIA_EXTENSIONS.common_video + MEDIA_EXTENSIONS.video
+MEDIA_EXTENSIONS.audio = MEDIA_EXTENSIONS.common_audio + MEDIA_EXTENSIONS.audio
+
+KNOWN_EXTENSIONS = (
+    MEDIA_EXTENSIONS.video + MEDIA_EXTENSIONS.audio
+    + MEDIA_EXTENSIONS.manifests
+)
+
+
+class _UnsafeExtensionError(Exception):
+    """
+    Mitigation exception for unwanted file overwrite/path traversal
+    This should be caught in YoutubeDL.py with a warning
+
+    Ref: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4j
+    """
+    _ALLOWED_EXTENSIONS = frozenset(itertools.chain(
+        (   # internal
+            'description',
+            'json',
+            'meta',
+            'orig',
+            'part',
+            'temp',
+            'uncut',
+            'unknown_video',
+            'ytdl',
+        ),
+        # video
+        MEDIA_EXTENSIONS.video, (
+            'avif',
+            'ismv',
+            'm2ts',
+            'm4s',
+            'mng',
+            'mpeg',
+            'qt',
+            'swf',
+            'ts',
+            'vp9',
+            'wvm',
+        ),
+        # audio
+        MEDIA_EXTENSIONS.audio, (
+            'isma',
+            'mid',
+            'mpga',
+            'ra',
+        ),
+        # image
+        MEDIA_EXTENSIONS.thumbnails, (
+            'bmp',
+            'gif',
+            'ico',
+            'heic',
+            'jng',
+            'jpeg',
+            'jxl',
+            'svg',
+            'tif',
+            'wbmp',
+        ),
+        # subtitle
+        MEDIA_EXTENSIONS.subtitles, (
+            'dfxp',
+            'fs',
+            'ismt',
+            'sami',
+            'scc',
+            'ssa',
+            'tt',
+        ),
+        # others
+        MEDIA_EXTENSIONS.manifests,
+        (
+            # not used in yt-dl
+            # *MEDIA_EXTENSIONS.storyboards,
+            # 'desktop',
+            # 'ism',
+            # 'm3u',
+            # 'sbv',
+            # 'swp',
+            # 'url',
+            # 'webloc',
+            # 'xml',
+        )))
+
+    def __init__(self, extension):
+        super(_UnsafeExtensionError, self).__init__('unsafe file extension: {0!r}'.format(extension))
+        self.extension = extension
+
+    @classmethod
+    def sanitize_extension(cls, extension, **kwargs):
+        # ... /, *, prepend=False
+        prepend = kwargs.get('prepend', False)
+
+        if '/' in extension or '\\' in extension:
+            raise cls(extension)
+
+        if not prepend:
+            last = extension.rpartition('.')[-1]
+            if last == 'bin':
+                extension = last = 'unknown_video'
+            if last.lower() not in cls._ALLOWED_EXTENSIONS:
+                raise cls(extension)
+
+        return extension