logo       
Google Custom Search
    AddThis Social Bookmark Button
-->

r26357 - lxml/trunk/src/lxml: msg#00101

Subject: r26357 - lxml/trunk/src/lxml
Author: scoder
Date: Wed Apr 26 12:38:22 2006
New Revision: 26357

Modified:
   lxml/trunk/src/lxml/nsclasses.pxi
   lxml/trunk/src/lxml/xslt.pxd
   lxml/trunk/src/lxml/xslt.pxi
Log:
rewrite of internal extension function registration
* register XSLT functions at module level directly in FunctionNamespace with
  xsltRegisterExtModuleFunction
* look up XPath functions at XPath run-time rather than registering everything
  at call-time
* new 'regexp' keyword for XSLT() to switch off regexp support


Modified: lxml/trunk/src/lxml/nsclasses.pxi
==============================================================================
--- lxml/trunk/src/lxml/nsclasses.pxi   (original)
+++ lxml/trunk/src/lxml/nsclasses.pxi   Wed Apr 26 12:38:22 2006
@@ -53,11 +53,19 @@
 cdef class _NamespaceRegistry:
     "Dictionary-like registry for namespace implementations"
     cdef object _ns_uri
+    cdef object _ns_uri_utf
     cdef object _classes
     cdef object _extensions
     cdef object _xslt_elements
+    cdef char* _c_ns_uri_utf
     def __init__(self, ns_uri):
         self._ns_uri = ns_uri
+        if ns_uri is None:
+            self._ns_uri_utf = None
+            self._c_ns_uri_utf = NULL
+        else:
+            self._ns_uri_utf = _utf8(ns_uri)
+            self._c_ns_uri_utf = _cstr(self._ns_uri_utf)
         self._classes = {}
         self._extensions = {}
         self._xslt_elements = {}
@@ -95,14 +103,11 @@
         d[name_utf] = item
 
     def __getitem__(self, name):
-        name_utf = _utf8(name)
-        return self._get(name_utf)
-
-    cdef object _get(self, object name):
         cdef python.PyObject* dict_result
-        dict_result = python.PyDict_GetItem(self._classes, name)
+        name_utf = _utf8(name)
+        dict_result = python.PyDict_GetItem(self._classes, name_utf)
         if dict_result is NULL:
-            dict_result = python.PyDict_GetItem(self._extensions, name)
+            dict_result = python.PyDict_GetItem(self._extensions, name_utf)
         if dict_result is NULL:
             raise KeyError, "Name not registered."
         return <object>dict_result
@@ -130,34 +135,34 @@
             self._prefix_utf = _utf8(prefix)
             self._prefix = prefix
 
-    def __setitem__(self, name, item):
-        if not callable(item):
+    def __setitem__(self, name, function):
+        if not callable(function):
             raise NamespaceRegistryError, "Registered function must be callable."
         if name is None:
             name_utf = None
         else:
             name_utf = _utf8(name)
-        self._extensions[name_utf] = item
+        self._extensions[name_utf] = function
+        _register_global_xslt_function(self._c_ns_uri_utf, _cstr(name_utf))
 
-    cdef object _get(self, object name):
+    def __getitem__(self, name):
         cdef python.PyObject* dict_result
-        dict_result = python.PyDict_GetItem(self._extensions, name)
+        name_utf = _utf8(name)
+        dict_result = python.PyDict_GetItem(self._extensions, name_utf)
         if dict_result is NULL:
             raise KeyError, "Name not registered."
         return <object>dict_result
 
+    def clear(self):
+        cdef char* c_uri_utf
+        c_uri_utf = self._c_ns_uri_utf
+        for name_utf in self._extensions:
+            _unregister_global_xslt_function(c_uri_utf, _cstr(name_utf))
+        _NamespaceRegistry.clear(self)
+
     def __repr__(self):
         return "FunctionNamespace(%r)" % self._ns_uri
 
-cdef object _find_all_extensions():
-    "Internal lookup function to find all extension functions for XSLT/XPath."
-    cdef _NamespaceRegistry registry
-    ns_extensions = {}
-    for (ns_utf, registry) in __FUNCTION_NAMESPACE_REGISTRIES.iteritems():
-        if registry._extensions:
-            ns_extensions[ns_utf] = registry._extensions
-    return ns_extensions
-
 cdef object _find_all_extension_prefixes():
     "Internal lookup function to find all function prefixes for XSLT/XPath."
     cdef _FunctionNamespaceRegistry registry
@@ -167,25 +172,18 @@
             ns_prefixes[registry._prefix_utf] = ns_utf
     return ns_prefixes
 
-cdef _find_extensions(namespaces):
-    """Returns a dictionary that maps each namespace in the provided list to a
-    dictionary of name-function mappings defined under that namespace."""
+cdef object _find_extension(ns_uri_utf, name_utf):
     cdef python.PyObject* dict_result
-    cdef char* c_ns_utf
-    extension_dict = {}
-    for ns_uri in namespaces:
-        if ns_uri is None:
-            ns_utf = None
-        else:
-            ns_utf = _utf8(ns_uri)
-        dict_result = python.PyDict_GetItem(
-            __FUNCTION_NAMESPACE_REGISTRIES, ns_utf)
-        if dict_result is NULL:
-            continue
-        extensions = (<_NamespaceRegistry>dict_result)._extensions
-        if extensions:
-            python.PyDict_SetItem(extension_dict, ns_utf, extensions)
-    return extension_dict
+    dict_result = python.PyDict_GetItem(
+        __FUNCTION_NAMESPACE_REGISTRIES, ns_uri_utf)
+    if dict_result is NULL:
+        return None
+    extensions = (<_NamespaceRegistry>dict_result)._extensions
+    dict_result = python.PyDict_GetItem(extensions, name_utf)
+    if dict_result is NULL:
+        return None
+    else:
+        return <object>dict_result
 
 cdef object _find_element_class(char* c_namespace_utf,
                                 char* c_element_name_utf):

Modified: lxml/trunk/src/lxml/xslt.pxd
==============================================================================
--- lxml/trunk/src/lxml/xslt.pxd        (original)
+++ lxml/trunk/src/lxml/xslt.pxd        Wed Apr 26 12:38:22 2006
@@ -19,6 +19,9 @@
                                      char* name,
                                      char * URI,
                                      xmlXPathFunction function)
+    cdef int xsltRegisterExtModuleFunction(char* name, char* URI,
+                                           xmlXPathFunction function)
+    cdef int xsltUnregisterExtModuleFunction(char* name, char* URI)
 
 cdef extern from "libxslt/transform.h":
     cdef xmlDoc* xsltApplyStylesheet(xsltStylesheet* style, xmlDoc* doc,

Modified: lxml/trunk/src/lxml/xslt.pxi
==============================================================================
--- lxml/trunk/src/lxml/xslt.pxi        (original)
+++ lxml/trunk/src/lxml/xslt.pxi        Wed Apr 26 12:38:22 2006
@@ -36,8 +36,8 @@
     cdef object _extensions
     cdef object _namespaces
     cdef object _registered_namespaces
-    cdef object _extension_functions
     cdef object _utf_refs
+    cdef object _temp_last_function
     # for exception handling and temporary reference keeping:
     cdef object _temp_elements
     cdef object _temp_docs
@@ -46,8 +46,9 @@
     def __init__(self, namespaces, extensions):
         self._xpathCtxt = NULL
         self._utf_refs = {}
+        self._temp_last_function = (None, None, None)
 
-        # fix old format extensions
+        # convert old format extensions to UTF-8
         if isinstance(extensions, (list, tuple)):
             new_extensions = {}
             for extension in extensions:
@@ -65,7 +66,6 @@
         self._extensions = extensions
         self._namespaces = namespaces
         self._registered_namespaces = []
-        self._extension_functions = {}
         self._temp_elements = {}
         self._temp_docs = {}
 
@@ -88,28 +88,19 @@
     cdef _register_context(self, _Document doc, int allow_none_namespace):
         self._doc      = doc
         self._exc_info = None
+        self._temp_last_function = (None, None, None)
         namespaces = self._namespaces
         if namespaces is not None:
             self.registerNamespaces(namespaces)
-            extensions = _find_extensions(namespaces.values())
-        else:
-            extensions = _find_all_extensions()
-        if self._extensions is not None:
-            # add user provided extensions
-            extensions.update(self._extensions)
-        if extensions:
-            if not allow_none_namespace:
-                python.PyDict_DelItem(extensions, None)
-            self._registerExtensionFunctions(extensions)
+        xpath.xmlXPathRegisterFuncLookup(self._xpathCtxt, _function_check,
+                                         <python.PyObject*>self)
 
     cdef _unregister_context(self):
-        self._unregisterExtensionFunctions()
         self._unregisterNamespaces()
         self._free_context()
 
     cdef _free_context(self):
         del self._registered_namespaces[:]
-        python.PyDict_Clear(self._extension_functions)
         python.PyDict_Clear(self._utf_refs)
         self._doc = None
         if self._xpathCtxt is not NULL:
@@ -139,33 +130,25 @@
         for prefix_utf in self._registered_namespaces:
             xpath.xmlXPathRegisterNs(xpathCtxt, prefix_utf, NULL)
     
-    # extension functions (internal UTF-8 methods with leading '_')
-
-    def registerExtensionFunctions(self, extensions):
-        for ns_uri, extension in extensions.items():
-            for name, function in extension.items():
-                self._registerExtensionFunction(
-                    self._to_utf(ns_uri), self._to_utf(name), function)
-
-    def registerExtensionFunction(self, ns_uri, name, function):
-        self._registerExtensionFunction(
-            self._to_utf(ns_uri), self._to_utf(name), function)
+    # extension functions
 
-    cdef _registerExtensionFunctions(self, extensions_utf):
-        for ns_uri_utf, extension in extensions_utf.items():
-            for name_utf, function in extension.items():
-                self._registerExtensionFunction(ns_uri_utf, name_utf, function)
+    cdef _lookup_extension(self, ns_uri_utf, name_utf):
+        cdef python.PyObject* dict_result
+        if self._temp_last_function[0] == ns_uri_utf and \
+           self._temp_last_function[1] == name_utf:
+            return self._temp_last_function[2]
 
-    cdef _registerExtensionFunction(self, ns_uri_utf, name_utf, function):
-        self._contextRegisterExtensionFunction(ns_uri_utf, name_utf)
-        self._extension_functions[(ns_uri_utf, name_utf)] = function
-
-    cdef _unregisterExtensionFunctions(self):
-        for ns_uri_utf, name_utf in self._extension_functions:
-            self._contextUnregisterExtensionFunction(ns_uri_utf, name_utf)
+        dict_result = python.PyDict_GetItem(self._extensions, ns_uri_utf)
+        if dict_result is not NULL:
+            dict_result = python.PyDict_GetItem(<object>dict_result, name_utf)
+        if dict_result is not NULL:
+            function = <object>dict_result
+        else:
+            function = _find_extension(ns_uri_utf, name_utf)
 
-    def find_extension(self, ns_uri_utf, name_utf):
-        return self._extension_functions[(ns_uri_utf, name_utf)]
+        # store temporarily as it will be looked up again in the next call
+        self._temp_last_function = (ns_uri_utf, name_utf, function)
+        return function
 
     # Python reference keeping during XPath function evaluation
 
@@ -194,86 +177,27 @@
                 #print "Holding document:", <int>element._doc._c_doc
                 python.PyDict_SetItem(self._temp_docs, id(element._doc), element._doc)
 
-################################################################################
-# EXSLT regexp implementation
-
-cdef object RE_COMPILE
-RE_COMPILE = re.compile
-
-cdef class _ExsltRegExp:
-    cdef object _compile_map
-    def __init__(self):
-        self._compile_map = {}
-
-    cdef _make_string(self, value):
-        if python.PyString_Check(value) or python.PyUnicode_Check(value):
-            return value
-        else:
-            raise TypeError, "Invalid argument type %s" % type(value)
-
-    cdef _compile(self, rexp, ignore_case):
-        cdef python.PyObject* c_result
-        rexp = self._make_string(rexp)
-        key = (rexp, ignore_case)
-        c_result = python.PyDict_GetItem(self._compile_map, key)
-        if c_result is not NULL:
-            return <object>c_result
-        py_flags = re.UNICODE
-        if ignore_case:
-            py_flags = py_flags | re.IGNORECASE
-        rexp_compiled = RE_COMPILE(rexp, py_flags)
-        python.PyDict_SetItem(self._compile_map, key, rexp_compiled)
-        return rexp_compiled
-
-    def test(self, ctxt, s, rexp, flags=''):
-        flags = self._make_string(flags)
-        s = self._make_string(s)
-        rexpc = self._compile(rexp, 'i' in flags)
-        if rexpc.search(s) is None:
-            return False
-        else:
-            return True
-
-    def match(self, ctxt, s, rexp, flags=''):
-        flags = self._make_string(flags)
-        s = self._make_string(s)
-        rexpc = self._compile(rexp, 'i' in flags)
-        if 'g' in flags:
-            results = rexpc.findall(s)
-            if not results:
-                return ()
-            result_list = []
-            root = Element('matches')
-            for s_match in results:
-                elem = SubElement(root, 'match')
-                elem.text = s_match
-                python.PyList_Append(result_list, elem)
-            return result_list
-        else:
-            result = rexpc.search(s)
-            if result is None:
-                return ()
-            root = Element('match')
-            root.text = result.group()
-            return (root,)
+cdef xpath.xmlXPathFunction _function_check(void* ctxt, char* c_name, char* c_ns_uri):
+    cdef _BaseContext context
+    if c_name is NULL:
+        return NULL
+    if c_ns_uri is NULL:
+        ns_uri = None
+    else:
+        ns_uri = c_ns_uri
+    context = <_BaseContext>ctxt
+    if context._lookup_extension(ns_uri, c_name) is None:
+        return NULL
+    else:
+        return _xpath_function_call
 
-    def replace(self, ctxt, s, rexp, flags, replacement):
-        replacement = self._make_string(replacement)
-        flags = self._make_string(flags)
-        s = self._make_string(s)
-        rexpc = self._compile(rexp, 'i' in flags)
-        if 'g' in flags:
-            count = 0
-        else:
-            count = 1
-        return rexpc.sub(replacement, s, count)
+cdef void _register_global_xslt_function(char* ns_uri, char* name):
+    xslt.xsltRegisterExtModuleFunction(ns_uri, name, _xpath_function_call)
 
-    cdef void _register_exslt_regexp(self, _BaseContext context):
-        ns = "http://exslt.org/regular-expressions";
-        context._registerExtensionFunction(ns, "test",    self.test)
-        context._registerExtensionFunction(ns, "match",   self.match)
-        context._registerExtensionFunction(ns, "replace", self.replace)
+cdef void _unregister_global_xslt_function(char* ns_uri, char* name):
+    xslt.xsltUnRegisterExtModuleFunction(ns_uri, name)
 
+ 
 
################################################################################
 # XSLT
 
@@ -281,13 +205,17 @@
     cdef xslt.xsltTransformContext* _xsltCtxt
     def __init__(self, namespaces, extensions):
         self._xsltCtxt = NULL
+        if extensions and None in extensions:
+            raise XSLTExtensionError, "extensions must have non-empty namespaces"
         _BaseContext.__init__(self, namespaces, extensions)
 
-    cdef register_context(self, xslt.xsltTransformContext* xsltCtxt, _Document doc):
+    cdef register_context(self, xslt.xsltTransformContext* xsltCtxt,
+                               _Document doc):
         self._xsltCtxt = xsltCtxt
         self._set_xpath_context(xsltCtxt.xpathCtxt)
         self._register_context(doc, 0)
         xsltCtxt.xpathCtxt.userData = <void*>self
+        self._registerLocalExtensionFunctions()
 
     cdef free_context(self):
         cdef xslt.xsltTransformContext* xsltCtxt
@@ -298,19 +226,31 @@
         self._xsltCtxt = NULL
         xslt.xsltFreeTransformContext(xsltCtxt)
 
-    def _contextRegisterExtensionFunction(self, ns_uri_utf, name_utf):
-        if ns_uri_utf is None:
-            raise XSLTExtensionError, "extensions must have non-empty namespaces"
+    cdef _registerLocalExtensionFunction(self, ns_utf, name_utf, function):
+        extensions = self._extensions
+        if self._extensions is None:
+            self._extensions = {ns_utf:{name_utf:function}}
+        else:
+            if ns_utf in self._extensions:
+                self._extensions[ns_utf][name_utf] = function
+            else:
+                self._extensions[ns_utf] = ns_extensions = {name_utf:function}
         xslt.xsltRegisterExtFunction(
-            self._xsltCtxt, _cstr(name_utf), _cstr(ns_uri_utf),
-            _xpathCallback)
+            self._xsltCtxt, _cstr(name_utf), _cstr(ns_utf),
+            _xpath_function_call)
 
-    def _contextUnregisterExtensionFunction(self, ns_uri_utf, name_utf):
-        if ns_uri_utf is not None:
-            xslt.xsltRegisterExtFunction(
-                self._xsltCtxt, _cstr(name_utf), _cstr(ns_uri_utf),
-                _xpathCallback)
+    cdef _registerLocalExtensionFunctions(self):
+        cdef xslt.xsltTransformContext* xsltCtxt
+        if self._extensions is None:
+            return
+        xsltCtxt = self._xsltCtxt
+        for ns_uri_utf, extension in self._extensions.items():
+            for name_utf, function in extension.items():
+                xslt.xsltRegisterExtFunction(
+                    xsltCtxt, _cstr(name_utf), _cstr(ns_uri_utf),
+                    _xpath_function_call)
 
+cdef class _ExsltRegExp # forward declaration
 
 cdef class XSLT:
     """Turn a document into an XSLT object.
@@ -320,7 +260,7 @@
     cdef _ExsltRegExp _regexp
     cdef object _doc_url_utf
     
-    def __init__(self, xslt_input, extensions=None):
+    def __init__(self, xslt_input, extensions=None, regexp=True):
         # make a copy of the document as stylesheet needs to assume it
         # doesn't change
         cdef xslt.xsltStylesheet* c_style
@@ -353,7 +293,10 @@
         self._c_style = c_style
 
         self._context = _XSLTContext(None, extensions)
-        self._regexp  = _ExsltRegExp()
+        if regexp:
+            self._regexp  = _ExsltRegExp()
+        else:
+            self._regexp  = None
         # XXX is it worthwile to use xsltPrecomputeStylesheet here?
         
     def __dealloc__(self):
@@ -403,7 +346,8 @@
 
         self._context._release_temp_refs()
         self._context.register_context(transform_ctxt, input_doc)
-        self._regexp._register_exslt_regexp(self._context)
+        if self._regexp is not None:
+            self._regexp._register_in_context(self._context)
 
         c_result = xslt.xsltApplyStylesheetUser(self._c_style, c_doc, params,
                                                 NULL, NULL, transform_ctxt)
@@ -452,6 +396,86 @@
     return result
 
 
################################################################################
+# EXSLT regexp implementation
+
+cdef object RE_COMPILE
+RE_COMPILE = re.compile
+
+cdef class _ExsltRegExp:
+    cdef object _compile_map
+    def __init__(self):
+        self._compile_map = {}
+
+    cdef _make_string(self, value):
+        if python.PyString_Check(value) or python.PyUnicode_Check(value):
+            return value
+        else:
+            raise TypeError, "Invalid argument type %s" % type(value)
+
+    cdef _compile(self, rexp, ignore_case):
+        cdef python.PyObject* c_result
+        rexp = self._make_string(rexp)
+        key = (rexp, ignore_case)
+        c_result = python.PyDict_GetItem(self._compile_map, key)
+        if c_result is not NULL:
+            return <object>c_result
+        py_flags = re.UNICODE
+        if ignore_case:
+            py_flags = py_flags | re.IGNORECASE
+        rexp_compiled = RE_COMPILE(rexp, py_flags)
+        python.PyDict_SetItem(self._compile_map, key, rexp_compiled)
+        return rexp_compiled
+
+    def test(self, ctxt, s, rexp, flags=''):
+        flags = self._make_string(flags)
+        s = self._make_string(s)
+        rexpc = self._compile(rexp, 'i' in flags)
+        if rexpc.search(s) is None:
+            return False
+        else:
+            return True
+
+    def match(self, ctxt, s, rexp, flags=''):
+        flags = self._make_string(flags)
+        s = self._make_string(s)
+        rexpc = self._compile(rexp, 'i' in flags)
+        if 'g' in flags:
+            results = rexpc.findall(s)
+            if not results:
+                return ()
+            result_list = []
+            root = Element('matches')
+            for s_match in results:
+                elem = SubElement(root, 'match')
+                elem.text = s_match
+                python.PyList_Append(result_list, elem)
+            return result_list
+        else:
+            result = rexpc.search(s)
+            if result is None:
+                return ()
+            root = Element('match')
+            root.text = result.group()
+            return (root,)
+
+    def replace(self, ctxt, s, rexp, flags, replacement):
+        replacement = self._make_string(replacement)
+        flags = self._make_string(flags)
+        s = self._make_string(s)
+        rexpc = self._compile(rexp, 'i' in flags)
+        if 'g' in flags:
+            count = 0
+        else:
+            count = 1
+        return rexpc.sub(replacement, s, count)
+
+    cdef _register_in_context(self, _XSLTContext context):
+        ns = "http://exslt.org/regular-expressions";
+        context._registerLocalExtensionFunction(ns, "test",    self.test)
+        context._registerLocalExtensionFunction(ns, "match",   self.match)
+        context._registerLocalExtensionFunction(ns, "replace", self.replace)
+
+################################################################################
 # XPath
 
 cdef class _XPathContext(_BaseContext):
@@ -507,24 +531,6 @@
         xpath.xmlXPathRegisterVariable(
             self._xpathCtxt, _cstr(name_utf), _wrapXPathObject(value))
 
-    def _contextRegisterExtensionFunction(self, ns_uri_utf, name_utf):
-        if ns_uri_utf is not None:
-            xpath.xmlXPathRegisterFuncNS(
-                self._xpathCtxt, _cstr(name_utf), _cstr(ns_uri_utf),
-                _xpathCallback)
-        else:
-            xpath.xmlXPathRegisterFunc(
-                self._xpathCtxt, _cstr(name_utf),
-                _xpathCallback)
-
-    def _contextUnregisterExtensionFunction(self, ns_uri_utf, name_utf):
-        if ns_uri_utf is not None:
-            xpath.xmlXPathRegisterFuncNS(
-                self._xpathCtxt, _cstr(name_utf), _cstr(ns_uri_utf), NULL)
-        else:
-            xpath.xmlXPathRegisterFunc(
-                self._xpathCtxt, _cstr(name_utf), NULL)
-
 
 cdef class XPathEvaluatorBase:
     cdef _XPathContext _context
@@ -807,7 +813,7 @@
             raise NotImplementedError
     return result
 
-cdef void _xpathCallback(xpath.xmlXPathParserContext* ctxt, int nargs):
+cdef void _xpath_function_call(xpath.xmlXPathParserContext* ctxt, int nargs):
     cdef xpath.xmlXPathContext* rctxt
     cdef _Document doc
     cdef xpath.xmlXPathObject* obj
@@ -826,7 +832,7 @@
     extensions = <_BaseContext>(rctxt.userData)
 
     # lookup up the extension function in the context
-    f = extensions.find_extension(uri, name)
+    f = extensions._lookup_extension(uri, name)
 
     args = []
     doc = extensions._doc


<Prev in Thread] Current Thread [Next in Thread>