[utils] Support list of xpath in xpath_element
authorSergey M․ <dstftw@gmail.com>
Sat, 31 Oct 2015 16:39:44 +0000 (22:39 +0600)
committerSergey M․ <dstftw@gmail.com>
Sat, 31 Oct 2015 16:39:44 +0000 (22:39 +0600)
test/test_utils.py
youtube_dl/utils.py

index 0c34f0e551dcc78121f288e032ab14dcea5eabad..5a56ad7767898bc7f69dcd5fa7e74c582886e6c3 100644 (file)
@@ -275,9 +275,16 @@ class TestUtil(unittest.TestCase):
         p = xml.etree.ElementTree.SubElement(div, 'p')
         p.text = 'Foo'
         self.assertEqual(xpath_element(doc, 'div/p'), p)
+        self.assertEqual(xpath_element(doc, ['div/p']), p)
+        self.assertEqual(xpath_element(doc, ['div/bar', 'div/p']), p)
         self.assertEqual(xpath_element(doc, 'div/bar', default='default'), 'default')
+        self.assertEqual(xpath_element(doc, ['div/bar'], default='default'), 'default')
         self.assertTrue(xpath_element(doc, 'div/bar') is None)
+        self.assertTrue(xpath_element(doc, ['div/bar']) is None)
+        self.assertTrue(xpath_element(doc, ['div/bar'], 'div/baz') is None)
         self.assertRaises(ExtractorError, xpath_element, doc, 'div/bar', fatal=True)
+        self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar'], fatal=True)
+        self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar', 'div/baz'], fatal=True)
 
     def test_xpath_text(self):
         testxml = '''<root>
index 558c9c7d5a21c646a11720221fd1648c137ac68e..89c88a4d305a36c389f137926d6beec30f028878 100644 (file)
@@ -178,10 +178,19 @@ def xpath_with_ns(path, ns_map):
 
 
 def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
-    if sys.version_info < (2, 7):  # Crazy 2.6
-        xpath = xpath.encode('ascii')
+    def _find_xpath(xpath):
+        if sys.version_info < (2, 7):  # Crazy 2.6
+            xpath = xpath.encode('ascii')
+        return node.find(xpath)
+
+    if isinstance(xpath, (str, compat_str)):
+        n = _find_xpath(xpath)
+    else:
+        for xp in xpath:
+            n = _find_xpath(xp)
+            if n is not None:
+                break
 
-    n = node.find(xpath)
     if n is None:
         if default is not NO_DEFAULT:
             return default