[compat] Improve compat_etree_iterfind for Py2.6
authordirkf <fieldhouse@gmx.net>
Tue, 28 May 2024 15:38:20 +0000 (16:38 +0100)
committerdirkf <fieldhouse@gmx.net>
Thu, 30 May 2024 14:46:36 +0000 (15:46 +0100)
Adapted from https://raw.githubusercontent.com/python/cpython/2.7/Lib/xml/etree/ElementPath.py

youtube_dl/compat.py

index d5485c7e802c8c612d19643e20ec5751a380d33c..0371896abc344a0c7fc6d038cd9e6ffa5b6a4e50 100644 (file)
@@ -2720,9 +2720,217 @@ if sys.version_info < (2, 7):
             xpath = xpath.encode('ascii')
         return xpath
 
-    def compat_etree_iterfind(element, match):
-        for from_ in element.findall(match):
-            yield from_
+    # further code below based on CPython 2.7 source
+    import functools
+
+    _xpath_tokenizer_re = re.compile(r'''(?x)
+        (                                   # (1)
+            '[^']*'|"[^"]*"|                # quoted strings, or
+            ::|//?|\.\.|\(\)|[/.*:[\]()@=]  # navigation specials
+        )|                                  # or (2)
+        ((?:\{[^}]+\})?[^/[\]()@=\s]+)|     # token: optional {ns}, no specials
+        \s+                                 # or white space
+    ''')
+
+    def _xpath_tokenizer(pattern, namespaces=None):
+        for token in _xpath_tokenizer_re.findall(pattern):
+            tag = token[1]
+            if tag and tag[0] != "{" and ":" in tag:
+                try:
+                    if not namespaces:
+                        raise KeyError
+                    prefix, uri = tag.split(":", 1)
+                    yield token[0], "{%s}%s" % (namespaces[prefix], uri)
+                except KeyError:
+                    raise SyntaxError("prefix %r not found in prefix map" % prefix)
+            else:
+                yield token
+
+    def _get_parent_map(context):
+        parent_map = context.parent_map
+        if parent_map is None:
+            context.parent_map = parent_map = {}
+            for p in context.root.getiterator():
+                for e in p:
+                    parent_map[e] = p
+        return parent_map
+
+    def _select(context, result, filter_fn=lambda *_: True):
+        for elem in result:
+            for e in elem:
+                if filter_fn(e, elem):
+                    yield e
+
+    def _prepare_child(next_, token):
+        tag = token[1]
+        return functools.partial(_select, filter_fn=lambda e, _: e.tag == tag)
+
+    def _prepare_star(next_, token):
+        return _select
+
+    def _prepare_self(next_, token):
+        return lambda _, result: (e for e in result)
+
+    def _prepare_descendant(next_, token):
+        token = next(next_)
+        if token[0] == "*":
+            tag = "*"
+        elif not token[0]:
+            tag = token[1]
+        else:
+            raise SyntaxError("invalid descendant")
+
+        def select(context, result):
+            for elem in result:
+                for e in elem.getiterator(tag):
+                    if e is not elem:
+                        yield e
+        return select
+
+    def _prepare_parent(next_, token):
+        def select(context, result):
+            # FIXME: raise error if .. is applied at toplevel?
+            parent_map = _get_parent_map(context)
+            result_map = {}
+            for elem in result:
+                if elem in parent_map:
+                    parent = parent_map[elem]
+                    if parent not in result_map:
+                        result_map[parent] = None
+                        yield parent
+        return select
+
+    def _prepare_predicate(next_, token):
+        signature = []
+        predicate = []
+        for token in next_:
+            if token[0] == "]":
+                break
+            if token[0] and token[0][:1] in "'\"":
+                token = "'", token[0][1:-1]
+            signature.append(token[0] or "-")
+            predicate.append(token[1])
+
+        def select(context, result, filter_fn=lambda _: True):
+            for elem in result:
+                if filter_fn(elem):
+                    yield elem
+
+        signature = "".join(signature)
+        # use signature to determine predicate type
+        if signature == "@-":
+            # [@attribute] predicate
+            key = predicate[1]
+            return functools.partial(
+                select, filter_fn=lambda el: el.get(key) is not None)
+        if signature == "@-='":
+            # [@attribute='value']
+            key = predicate[1]
+            value = predicate[-1]
+            return functools.partial(
+                select, filter_fn=lambda el: el.get(key) == value)
+        if signature == "-" and not re.match(r"\d+$", predicate[0]):
+            # [tag]
+            tag = predicate[0]
+            return functools.partial(
+                select, filter_fn=lambda el: el.find(tag) is not None)
+        if signature == "-='" and not re.match(r"\d+$", predicate[0]):
+            # [tag='value']
+            tag = predicate[0]
+            value = predicate[-1]
+
+            def itertext(el):
+                for e in el.getiterator():
+                    e = e.text
+                    if e:
+                        yield e
+
+            def select(context, result):
+                for elem in result:
+                    for e in elem.findall(tag):
+                        if "".join(itertext(e)) == value:
+                            yield elem
+                            break
+            return select
+        if signature == "-" or signature == "-()" or signature == "-()-":
+            # [index] or [last()] or [last()-index]
+            if signature == "-":
+                index = int(predicate[0]) - 1
+            else:
+                if predicate[0] != "last":
+                    raise SyntaxError("unsupported function")
+                if signature == "-()-":
+                    try:
+                        index = int(predicate[2]) - 1
+                    except ValueError:
+                        raise SyntaxError("unsupported expression")
+                else:
+                    index = -1
+
+            def select(context, result):
+                parent_map = _get_parent_map(context)
+                for elem in result:
+                    try:
+                        parent = parent_map[elem]
+                        # FIXME: what if the selector is "*" ?
+                        elems = list(parent.findall(elem.tag))
+                        if elems[index] is elem:
+                            yield elem
+                    except (IndexError, KeyError):
+                        pass
+            return select
+        raise SyntaxError("invalid predicate")
+
+    ops = {
+        "": _prepare_child,
+        "*": _prepare_star,
+        ".": _prepare_self,
+        "..": _prepare_parent,
+        "//": _prepare_descendant,
+        "[": _prepare_predicate,
+    }
+
+    _cache = {}
+
+    class _SelectorContext:
+        parent_map = None
+
+        def __init__(self, root):
+            self.root = root
+
+    ##
+    # Generate all matching objects.
+
+    def compat_etree_iterfind(elem, path, namespaces=None):
+        # compile selector pattern
+        if path[-1:] == "/":
+            path = path + "*"  # implicit all (FIXME: keep this?)
+        try:
+            selector = _cache[path]
+        except KeyError:
+            if len(_cache) > 100:
+                _cache.clear()
+            if path[:1] == "/":
+                raise SyntaxError("cannot use absolute path on element")
+            tokens = _xpath_tokenizer(path, namespaces)
+            selector = []
+            for token in tokens:
+                if token[0] == "/":
+                    continue
+                try:
+                    selector.append(ops[token[0]](tokens, token))
+                except StopIteration:
+                    raise SyntaxError("invalid path")
+            _cache[path] = selector
+        # execute selector pattern
+        result = [elem]
+        context = _SelectorContext(elem)
+        for select in selector:
+            result = select(context, result)
+        return result
+
+    # end of code based on CPython 2.7 source
+
 
 else:
     compat_xpath = lambda xpath: xpath