[utils] Multiple changes to base_n()
authorYen Chi Hsuan <yan12125@gmail.com>
Fri, 26 Feb 2016 19:19:50 +0000 (03:19 +0800)
committerYen Chi Hsuan <yan12125@gmail.com>
Fri, 26 Feb 2016 19:22:52 +0000 (03:22 +0800)
1. Renamed to encode_base_n()
2. Allow tables longer than 62 characters
3. Raise ValueError instead of AssertionError for invalid input data
4. Return the first character in the table instead of '0' for number 0
5. Add tests

test/test_utils.py
youtube_dl/utils.py

index d0736f435093403442c303b64e7024f4389a70ae..97587ad2f56215c5d57fd06c45b55a7c8ebf9d8e 100644 (file)
@@ -18,6 +18,7 @@ import xml.etree.ElementTree
 from youtube_dl.utils import (
     age_restricted,
     args_to_str,
+    encode_base_n,
     clean_html,
     DateRange,
     detect_exe_version,
@@ -802,5 +803,16 @@ The first line
             ohdave_rsa_encrypt(b'aa111222', e, N),
             '726664bd9a23fd0c70f9f1b84aab5e3905ce1e45a584e9cbcf9bcc7510338fc1986d6c599ff990d923aa43c51c0d9013cd572e13bc58f4ae48f2ed8c0b0ba881')
 
+    def test_encode_base_n(self):
+        self.assertEqual(encode_base_n(0, 30), '0')
+        self.assertEqual(encode_base_n(80, 30), '2k')
+
+        custom_table = '9876543210ZYXWVUTSRQPONMLKJIHGFEDCBA'
+        self.assertEqual(encode_base_n(0, 30, custom_table), '9')
+        self.assertEqual(encode_base_n(80, 30, custom_table), '7P')
+
+        self.assertRaises(ValueError, encode_base_n, 0, 70)
+        self.assertRaises(ValueError, encode_base_n, 0, 60, custom_table)
+
 if __name__ == '__main__':
     unittest.main()
index 756ad4fd10e28e69be4a6e6ea0eb5c67ffa44970..606977c5874ac613f0224815e8c91fc5ef62400a 100644 (file)
@@ -2621,15 +2621,17 @@ def ohdave_rsa_encrypt(data, exponent, modulus):
     return '%x' % encrypted
 
 
-def base_n(num, n, table=None):
-    if num == 0:
-        return '0'
-
+def encode_base_n(num, n, table=None):
     FULL_TABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
-    assert n <= len(FULL_TABLE)
     if not table:
         table = FULL_TABLE[:n]
 
+    if n > len(table):
+        raise ValueError('base %d exceeds table length %d' % (n, len(table)))
+
+    if num == 0:
+        return table[0]
+
     ret = ''
     while num:
         ret = table[num % n] + ret
@@ -2649,7 +2651,7 @@ def decode_packed_codes(code):
 
     while count:
         count -= 1
-        base_n_count = base_n(count, base)
+        base_n_count = encode_base_n(count, base)
         symbol_table[base_n_count] = symbols[count] or base_n_count
 
     return re.sub(