helper.py (10744B)
"""Helpers shared by the test modules.

Provides parameter loading, a fake downloader (``FakeYDL``) and assertion
helpers that compare extracted info dicts against test expectations.
"""

from __future__ import unicode_literals

import errno
import io
import hashlib
import json
import os.path
import re
import types
import ssl
import sys

import youtube_dl.extractor
from youtube_dl import YoutubeDL
from youtube_dl.compat import (
    compat_os_name,
    compat_str,
)
from youtube_dl.utils import (
    preferredencoding,
    write_string,
)


def get_params(override=None):
    """Load downloader parameters for tests.

    Reads ``parameters.json`` from this file's directory. If
    ``local_parameters.json`` exists there, its entries are applied on top.
    Finally, ``override`` (a dict), if given, is applied on top of that.
    A new dict is built on every call.
    """
    PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   "parameters.json")
    LOCAL_PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                         "local_parameters.json")
    with io.open(PARAMETERS_FILE, encoding='utf-8') as pf:
        parameters = json.load(pf)
    if os.path.exists(LOCAL_PARAMETERS_FILE):
        with io.open(LOCAL_PARAMETERS_FILE, encoding='utf-8') as pf:
            parameters.update(json.load(pf))
    if override:
        parameters.update(override)
    return parameters


def try_rm(filename):
    """Remove ``filename`` if it exists.

    A missing file (ENOENT) is ignored. Any other OSError is re-raised.
    """
    try:
        os.remove(filename)
    except OSError as ose:
        if ose.errno != errno.ENOENT:
            raise


def report_warning(message):
    '''
    Print ``message`` to stderr, prefixed with 'WARNING:'.

    The prefix is wrapped in ANSI colour escapes when stderr is a TTY and
    the OS is not Windows ('nt'). The text is encoded with the preferred
    encoding when stderr is opened in binary mode or on Python 2.
    '''
    if sys.stderr.isatty() and compat_os_name != 'nt':
        _msg_header = '\033[0;33mWARNING:\033[0m'
    else:
        _msg_header = 'WARNING:'
    output = '%s %s\n' % (_msg_header, message)
    if 'b' in getattr(sys.stderr, 'mode', '') or sys.version_info[0] < 3:
        output = output.encode(preferredencoding())
    sys.stderr.write(output)


class FakeYDL(YoutubeDL):
    """YoutubeDL subclass for tests.

    Screen output goes to ``print``, ``trouble()`` raises instead of
    reporting, and ``download()`` only records its argument in
    ``self.result``.
    """

    def __init__(self, override=None):
        # Different instances of the downloader can't share the same dictionary
        # some test set the "sublang" parameter, which would break the md5 checks.
        params = get_params(override=override)
        super(FakeYDL, self).__init__(params, auto_init=False)
        self.result = []

    def to_screen(self, s, skip_eol=None):
        print(s)

    def trouble(self, s, tb=None):
        # Turn any reported trouble into a hard failure of the test.
        raise Exception(s)

    def download(self, x):
        # Record what would have been downloaded instead of downloading it.
        self.result.append(x)

    def expect_warning(self, regex):
        """Silence warnings whose message matches ``regex`` (via re.match).

        Non-matching warnings are passed on to the previous
        ``report_warning``.
        """
        old_report_warning = self.report_warning

        def report_warning(self, message):
            if re.match(regex, message):
                return
            old_report_warning(message)
        self.report_warning = types.MethodType(report_warning, self)


def gettestcases(include_onlymatching=False):
    """Yield the test cases of every extractor from gen_extractors()."""
    for ie in youtube_dl.extractor.gen_extractors():
        for tc in ie.get_testcases(include_onlymatching):
            yield tc


def md5(s):
    """Return the hex MD5 digest of the UTF-8 encoding of text ``s``."""
    return hashlib.md5(s.encode('utf-8')).hexdigest()


def expect_value(self, got, expected, field):
    """Assert that ``got`` satisfies the test expectation ``expected``.

    ``self`` is the running test case. ``field`` is used only in failure
    messages. ``expected`` may be:

    * ``'re:PATTERN'``: got is text and PATTERN matches at its start.
    * ``'startswith:S'`` / ``'contains:S'``: got is text with that prefix /
      substring.
    * a type: got is an instance of it.
    * a dict (when got is a dict too): checked recursively via expect_dict.
    * a list (when got is a list too): same length, same per-item types,
      and each item is checked recursively.
    * ``'md5:HEX'``: got is text whose md5 hex digest equals HEX.
    * ``'count:N'`` / ``'mincount:N'`` / ``'maxcount:N'``: got is a list or
      dict whose length is ==, >= or <= N.
    * anything else: equality.
    """
    if isinstance(expected, compat_str) and expected.startswith('re:'):
        match_str = expected[len('re:'):]
        match_rex = re.compile(match_str)

        self.assertTrue(
            isinstance(got, compat_str),
            'Expected a %s object, but got %s for field %s' % (
                compat_str.__name__, type(got).__name__, field))
        self.assertTrue(
            match_rex.match(got),
            'field %s (value: %r) should match %r' % (field, got, match_str))
    elif isinstance(expected, compat_str) and expected.startswith('startswith:'):
        start_str = expected[len('startswith:'):]
        self.assertTrue(
            isinstance(got, compat_str),
            'Expected a %s object, but got %s for field %s' % (
                compat_str.__name__, type(got).__name__, field))
        self.assertTrue(
            got.startswith(start_str),
            'field %s (value: %r) should start with %r' % (field, got, start_str))
    elif isinstance(expected, compat_str) and expected.startswith('contains:'):
        contains_str = expected[len('contains:'):]
        self.assertTrue(
            isinstance(got, compat_str),
            'Expected a %s object, but got %s for field %s' % (
                compat_str.__name__, type(got).__name__, field))
        self.assertTrue(
            contains_str in got,
            'field %s (value: %r) should contain %r' % (field, got, contains_str))
    elif isinstance(expected, type):
        self.assertTrue(
            isinstance(got, expected),
            'Expected type %r for field %s, but got value %r of type %r' % (expected, field, got, type(got)))
    elif isinstance(expected, dict) and isinstance(got, dict):
        expect_dict(self, got, expected)
    elif isinstance(expected, list) and isinstance(got, list):
        self.assertEqual(
            len(expected), len(got),
            'Expect a list of length %d, but got a list of length %d for field %s' % (
                len(expected), len(got), field))
        for index, (item_got, item_expected) in enumerate(zip(got, expected)):
            type_got = type(item_got)
            type_expected = type(item_expected)
            self.assertEqual(
                type_expected, type_got,
                'Type mismatch for list item at index %d for field %s, expected %r, got %r' % (
                    index, field, type_expected, type_got))
            expect_value(self, item_got, item_expected, field)
    else:
        if isinstance(expected, compat_str) and expected.startswith('md5:'):
            self.assertTrue(
                isinstance(got, compat_str),
                'Expected field %s to be a unicode object, but got value %r of type %r' % (field, got, type(got)))
            # Hash the actual value so the final equality check below
            # compares digests.
            got = 'md5:' + md5(got)
        elif isinstance(expected, compat_str) and re.match(r'^(?:min|max)?count:\d+', expected):
            self.assertTrue(
                isinstance(got, (list, dict)),
                'Expected field %s to be a list or a dict, but it is of type %s' % (
                    field, type(got).__name__))
            op, _, expected_num = expected.partition(':')
            expected_num = int(expected_num)
            if op == 'mincount':
                assert_func = assertGreaterEqual
                msg_tmpl = 'Expected %d items in field %s, but only got %d'
            elif op == 'maxcount':
                assert_func = assertLessEqual
                msg_tmpl = 'Expected maximum %d items in field %s, but got %d'
            elif op == 'count':
                assert_func = assertEqual
                msg_tmpl = 'Expected exactly %d items in field %s, but got %d'
            else:
                # Unreachable: the regex above only admits these three ops.
                assert False
            assert_func(
                self, len(got), expected_num,
                msg_tmpl % (expected_num, field, len(got)))
            return
        self.assertEqual(
            expected, got,
            'Invalid value for field %s, expected %r, got %r' % (field, expected, got))


def expect_dict(self, got_dict, expected_dict):
    """Check every field of ``expected_dict`` against ``got_dict``.

    Fields absent from ``got_dict`` are compared as None.
    """
    for info_field, expected in expected_dict.items():
        got = got_dict.get(info_field)
        expect_value(self, got, expected, info_field)


def expect_info_dict(self, got_dict, expected_dict):
    """Validate an extracted info dict against a test's ``info_dict``.

    Besides the per-field checks, this also:

    * requires the mandatory fields (unless ``_type`` is playlist or
      multi_video), plus the fields YoutubeDL sets automatically;
    * fails if commonly checked fields present in ``got_dict`` are missing
      from ``expected_dict``. In that case a ready-to-paste ``info_dict``
      snippet is written to stderr first.
    """
    expect_dict(self, got_dict, expected_dict)
    # Check for the presence of mandatory fields
    if got_dict.get('_type') not in ('playlist', 'multi_video'):
        for key in ('id', 'url', 'title', 'ext'):
            self.assertTrue(got_dict.get(key), 'Missing mandatory field %s' % key)
    # Check for mandatory fields that are automatically set by YoutubeDL
    for key in ['webpage_url', 'extractor', 'extractor_key']:
        self.assertTrue(got_dict.get(key), 'Missing field: %s' % key)

    # Are checkable fields missing from the test case definition?
    # Long text values are replaced by their md5 so the snippet stays short.
    test_info_dict = dict((key, value if not isinstance(value, compat_str) or len(value) < 250 else 'md5:' + md5(value))
                          for key, value in got_dict.items()
                          if value and key in ('id', 'title', 'description', 'uploader', 'upload_date', 'timestamp', 'uploader_id', 'location', 'age_limit'))
    missing_keys = set(test_info_dict.keys()) - set(expected_dict.keys())
    if missing_keys:
        def _repr(v):
            # Render text as a single-quoted literal suitable for pasting
            # into a test definition.
            if isinstance(v, compat_str):
                return "'%s'" % v.replace('\\', '\\\\').replace("'", "\\'").replace('\n', '\\n')
            else:
                return repr(v)
        info_dict_str = ''
        if len(missing_keys) != len(expected_dict):
            # Iterate in sorted order so the snippet is identical from run to
            # run (set/dict iteration order is not stable across runs).
            info_dict_str += ''.join(
                '    %s: %s,\n' % (_repr(k), _repr(test_info_dict[k]))
                for k in sorted(test_info_dict) if k not in missing_keys)

            if info_dict_str:
                info_dict_str += '\n'
        info_dict_str += ''.join(
            '    %s: %s,\n' % (_repr(k), _repr(test_info_dict[k]))
            for k in sorted(missing_keys))
        write_string(
            '\n\'info_dict\': {\n' + info_dict_str + '},\n', out=sys.stderr)
        self.assertFalse(
            missing_keys,
            'Missing keys in test definition: %s' % (
                ', '.join(sorted(missing_keys))))


def assertRegexpMatches(self, text, regexp, msg=None):
    """Assert that ``regexp`` matches at the start of ``text`` (re.match)."""
    # NOTE(review): the standard unittest.TestCase defines assertRegex /
    # assertRegexpMatches, not 'assertRegexp', so the fallback branch is what
    # normally runs. It is kept as-is because the unittest methods use
    # re.search, which would loosen the anchored match callers get today.
    if hasattr(self, 'assertRegexp'):
        return self.assertRegexp(text, regexp, msg)
    else:
        m = re.match(regexp, text)
        if not m:
            note = 'Regexp didn\'t match: %r not found' % (regexp)
            # Only include the text itself when it is reasonably short.
            if len(text) < 1000:
                note += ' in %r' % text
            if msg is None:
                msg = note
            else:
                msg = note + ', ' + msg
        self.assertTrue(m, msg)


def assertGreaterEqual(self, got, expected, msg=None):
    """Assert ``got >= expected``, with a default failure message."""
    if not (got >= expected):
        if msg is None:
            msg = '%r not greater than or equal to %r' % (got, expected)
    self.assertTrue(got >= expected, msg)


def assertLessEqual(self, got, expected, msg=None):
    """Assert ``got <= expected``, with a default failure message."""
    if not (got <= expected):
        if msg is None:
            msg = '%r not less than or equal to %r' % (got, expected)
    self.assertTrue(got <= expected, msg)


def assertEqual(self, got, expected, msg=None):
    """Assert ``got == expected``, with a default failure message."""
    if not (got == expected):
        if msg is None:
            msg = '%r not equal to %r' % (got, expected)
    self.assertTrue(got == expected, msg)


def expect_warnings(ydl, warnings_re):
    """Suppress ``ydl`` warnings matched (re.search) by any of ``warnings_re``.

    All other warnings are forwarded to the original ``report_warning``.
    """
    real_warning = ydl.report_warning

    def _report_warning(w):
        if not any(re.search(w_re, w) for w_re in warnings_re):
            real_warning(w)

    ydl.report_warning = _report_warning


def http_server_port(httpd):
    """Return the port number that ``httpd``'s socket is bound to."""
    if os.name == 'java' and isinstance(httpd.socket, ssl.SSLSocket):
        # In Jython SSLSocket is not a subclass of socket.socket
        sock = httpd.socket.sock
    else:
        sock = httpd.socket
    return sock.getsockname()[1]