diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 20b6c3ce..914ba5d6 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -133,7 +133,7 @@ jobs: run: | source "$CONDA/etc/profile.d/conda.sh" conda activate test_mkl_umath - python -c "import mkl_umath, numpy as np; mkl_umath.use_in_numpy(); np.sin(np.linspace(0, 1, num=10**6));" + python -c "import mkl_umath, numpy as np; mkl_umath.patch_numpy_umath(); np.sin(np.linspace(0, 1, num=10**6));" - name: Run tests run: | @@ -328,7 +328,7 @@ jobs: run: | @ECHO ON conda activate mkl_umath_test - python -c "import mkl_umath, numpy as np; mkl_umath.use_in_numpy(); np.sin(np.linspace(0, 1, num=10**6));" + python -c "import mkl_umath, numpy as np; mkl_umath.patch_numpy_umath(); np.sin(np.linspace(0, 1, num=10**6));" - name: Run tests shell: cmd /C CALL {0} diff --git a/AGENTS.md b/AGENTS.md index 71aaa95e..e87f52a7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,13 +7,13 @@ Entry point for agent context in this repo. It provides: - `mkl_umath._ufuncs` — OneMKL-backed NumPy ufunc loops -- `mkl_umath._patch` — runtime patching interface (`use_in_numpy()`, `restore()`, `is_patched()`) +- `mkl_umath._patch_numpy` — runtime patching interface (`patch_numpy_umath` `restore_numpy_umath`, `is_patched()`) - Performance-optimized math operations (sin, cos, exp, log, etc.) using Intel MKL VM ## Key components - **Python interface:** `mkl_umath/__init__.py`, `_init_helper.py` - **Core C implementation:** `mkl_umath/src/` (ufuncsmodule.c, mkl_umath_loops.c.src) -- **Cython patch layer:** `mkl_umath/src/_patch.pyx` +- **Cython patch layer:** `mkl_umath/src/_patch_numpy.pyx` - **Code generation:** `generate_umath.py`, `generate_umath_doc.py` - **Build system:** CMake (CMakeLists.txt) + scikit-build @@ -50,9 +50,9 @@ CC=${CC:-icx} pip install --no-build-isolation --no-deps . # clang is also supp ## Usage ```python import mkl_umath -mkl_umath.use_in_numpy() # Patch NumPy to use MKL loops +mkl_umath.patch_numpy_umath() # Patch NumPy to use MKL loops # ... perform NumPy operations (now accelerated) ... -mkl_umath.restore() # Restore original NumPy loops +mkl_umath.restore_numpy_umath() # Restore original NumPy loops ``` ## How to work in this repo diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a00e86f..38eb6fb7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,13 +135,13 @@ if (UNIX) endif() install(TARGETS _ufuncs LIBRARY DESTINATION mkl_umath) -add_cython_target(_patch "mkl_umath/src/_patch.pyx" C OUTPUT_VAR _generated_src) -Python_add_library(_patch MODULE WITH_SOABI ${_generated_src}) -target_include_directories(_patch PRIVATE "mkl_umath/src/" ${Python_NumPy_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}) -target_compile_definitions(_patch PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION) -target_link_libraries(_patch PRIVATE mkl_umath_loops) -set_target_properties(_patch PROPERTIES C_STANDARD 99) +add_cython_target(_patch_numpy "mkl_umath/src/_patch_numpy.pyx" C OUTPUT_VAR _generated_src) +Python_add_library(_patch_numpy MODULE WITH_SOABI ${_generated_src}) +target_include_directories(_patch_numpy PRIVATE "mkl_umath/src/" ${Python_NumPy_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}) +target_compile_definitions(_patch_numpy PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION) +target_link_libraries(_patch_numpy PRIVATE mkl_umath_loops) +set_target_properties(_patch_numpy PROPERTIES C_STANDARD 99) if (UNIX) - set_target_properties(_patch PROPERTIES INSTALL_RPATH "$ORIGIN/../..;$ORIGIN/../../..;$ORIGIN") + set_target_properties(_patch_numpy PROPERTIES INSTALL_RPATH "$ORIGIN/../..;$ORIGIN/../../..;$ORIGIN") endif() -install(TARGETS _patch LIBRARY DESTINATION mkl_umath) +install(TARGETS _patch_numpy LIBRARY DESTINATION mkl_umath) diff --git a/conda-recipe-cf/meta.yaml b/conda-recipe-cf/meta.yaml index 92d5a5db..f678f9af 100644 --- a/conda-recipe-cf/meta.yaml +++ b/conda-recipe-cf/meta.yaml @@ -44,7 +44,7 @@ test: imports: - mkl_umath - mkl_umath._ufuncs - - mkl_umath._patch + - mkl_umath._patch_numpy about: home: http://github.com/IntelPython/mkl_umath diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 38ea112f..4b15337c 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -48,7 +48,7 @@ test: imports: - mkl_umath - mkl_umath._ufuncs - - mkl_umath._patch + - mkl_umath._patch_numpy about: home: http://github.com/IntelPython/mkl_umath diff --git a/mkl_umath/AGENTS.md b/mkl_umath/AGENTS.md index 06b55865..f3c38bed 100644 --- a/mkl_umath/AGENTS.md +++ b/mkl_umath/AGENTS.md @@ -14,9 +14,9 @@ Core MKL-backed ufunc implementation: Python interface, Cython patching, and C/M ## Patching API ```python -mkl_umath.use_in_numpy() # Replace NumPy loops with MKL -mkl_umath.restore() # Restore original NumPy loops -mkl_umath.is_patched() # Check patch status +mkl_umath.patch_numpy_umath() # Replace NumPy loops with MKL +mkl_umath.restore_numpy_umath() # Restore original NumPy loops +mkl_umath.is_patched() # Check patch status ``` ## Development guardrails @@ -31,6 +31,6 @@ mkl_umath.is_patched() # Check patch status - Docstrings: dual NumPy 1.x/2.x support via separate docstring modules ## Notes -- `_patch.pyx` is Cython; changes require Cython rebuild +- `_patch_numpy.pyx` is Cython; changes require Cython rebuild - MKL VM loops in `src/mkl_umath_loops.c.src` - `src/ufuncsmodule.c` — NumPy ufunc registration and dispatch diff --git a/mkl_umath/__init__.py b/mkl_umath/__init__.py index 0477d5a2..d0603992 100644 --- a/mkl_umath/__init__.py +++ b/mkl_umath/__init__.py @@ -29,7 +29,14 @@ """ from . import _init_helper -from ._patch import is_patched, mkl_umath, restore, use_in_numpy +from ._patch_numpy import ( + is_patched, + mkl_umath, + patch_numpy_umath, + restore, + restore_numpy_umath, + use_in_numpy, +) from ._ufuncs import * from ._version import __version__ diff --git a/mkl_umath/src/AGENTS.md b/mkl_umath/src/AGENTS.md index 83d978e7..9ff52362 100644 --- a/mkl_umath/src/AGENTS.md +++ b/mkl_umath/src/AGENTS.md @@ -7,7 +7,7 @@ C/Cython implementation layer: MKL VM integration, ufunc loops, and NumPy patchi - **ufuncsmodule.h** — ufunc module public headers - **mkl_umath_loops.c.src** — MKL VM loop implementations (template, ~60k LOC) - **mkl_umath_loops.h.src** — loop function declarations (template) -- **_patch.pyx** — Cython patching layer (runtime NumPy loop replacement) +- **_patch_numpy.pyx** — Cython patching layer (runtime NumPy loop replacement) - **fast_loop_macros.h** — loop generation macros - **blocking_utils.h** — blocking/chunking utilities for large arrays @@ -21,15 +21,16 @@ C/Cython implementation layer: MKL VM integration, ufunc loops, and NumPy patchi - Blocking strategy: chunk large arrays for cache efficiency - Error handling: MKL VM status → NumPy error state -## Patching mechanism (_patch.pyx) -- Cython extension exposing `use_in_numpy()`, `restore()`, `is_patched()` +## Patching mechanism (_patch_numpy.pyx) +- Cython extension exposing `patch_numpy_umath()`, `restore_numpy_umath()`, + `is_patched()` - Replaces function pointers in NumPy's ufunc loop tables - Thread-safe: guards against concurrent patching - Reversible: stores original pointers for restoration ## Build output - `mkl_umath_loops.c` → shared library (libmkl_umath_loops.so/.dll) -- `_patch.pyx` → Python extension (_patch.*.so) +- `_patch_numpy.pyx` → Python extension (_patch.*.so) - `ufuncsmodule.c` + `__umath_generated.c` → `_ufuncs` extension ## Development notes diff --git a/mkl_umath/src/_patch.pyx b/mkl_umath/src/_patch.pyx deleted file mode 100644 index 8d2d299d..00000000 --- a/mkl_umath/src/_patch.pyx +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) 2019, Intel Corporation -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of Intel Corporation nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# distutils: language = c -# cython: language_level=3 - -import mkl_umath._ufuncs as mu - -cimport numpy as cnp - -import numpy as np - -from libc.stdlib cimport free, malloc - -cnp.import_umath() - - -ctypedef struct function_info: - cnp.PyUFuncGenericFunction original_function - cnp.PyUFuncGenericFunction patch_function - int* signature - - -cdef class patch: - cdef int functions_count - cdef function_info* functions - cdef bint _is_patched - - functions_dict = dict() - - def __cinit__(self): - cdef int pi, oi - - self._is_patched = False - - umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)] - self.functions_count = 0 - for umath in umaths: - mkl_umath = getattr(mu, umath) - self.functions_count += mkl_umath.ntypes - - self.functions = malloc( - self.functions_count * sizeof(function_info) - ) - - func_number = 0 - for umath in umaths: - patch_umath = getattr(mu, umath) - c_patch_umath = patch_umath - c_orig_umath = getattr(np, umath) - nargs = c_patch_umath.nargs - for pi in range(c_patch_umath.ntypes): - oi = 0 - while oi < c_orig_umath.ntypes: - found = True - for i in range(c_patch_umath.nargs): - if ( - c_patch_umath.types[pi * nargs + i] - != c_orig_umath.types[oi * nargs + i] - ): - found = False - break - if found is True: - break - oi = oi + 1 - if oi < c_orig_umath.ntypes: - self.functions[func_number].original_function = ( - c_orig_umath.functions[oi] - ) - self.functions[func_number].patch_function = ( - c_patch_umath.functions[pi] - ) - self.functions[func_number].signature = ( - malloc(nargs * sizeof(int)) - ) - for i in range(nargs): - self.functions[func_number].signature[i] = ( - c_patch_umath.types[pi * nargs + i] - ) - self.functions_dict[(umath, patch_umath.types[pi])] = ( - func_number - ) - func_number = func_number + 1 - else: - raise RuntimeError( - f"Unable to find original function for: {umath} " - f"{patch_umath.types[pi]}" - ) - - def __dealloc__(self): - for i in range(self.functions_count): - free(self.functions[i].signature) - free(self.functions) - - def do_patch(self): - cdef int _res - cdef cnp.PyUFuncGenericFunction temp - cdef cnp.PyUFuncGenericFunction function - cdef int* signature - - for func in self.functions_dict: - np_umath = getattr(np, func[0]) - index = self.functions_dict[func] - function = self.functions[index].patch_function - signature = self.functions[index].signature - # TODO: check res, 0 means success, -1 means error - _res = cnp.PyUFunc_ReplaceLoopBySignature( - np_umath, function, signature, &temp - ) - - self._is_patched = True - - def do_unpatch(self): - cdef int _res - cdef cnp.PyUFuncGenericFunction temp - cdef cnp.PyUFuncGenericFunction function - cdef int* signature - - for func in self.functions_dict: - np_umath = getattr(np, func[0]) - index = self.functions_dict[func] - function = self.functions[index].original_function - signature = self.functions[index].signature - # TODO: check res, 0 means success, -1 means error - _res = cnp.PyUFunc_ReplaceLoopBySignature( - np_umath, function, signature, &temp - ) - - self._is_patched = False - - def is_patched(self): - return self._is_patched - -from threading import local as threading_local - -_tls = threading_local() - - -def _is_tls_initialized(): - return getattr(_tls, "initialized", False) - - -def _initialize_tls(): - _tls.patch = patch() - _tls.initialized = True - - -def use_in_numpy(): - """ - Enables using of mkl_umath in Numpy. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # True - - >>> mkl_umath.restore() # Disable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # False - - """ - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_patch() - - -def restore(): - """ - Disables using of mkl_umath in Numpy. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # True - - >>> mkl_umath.restore() # Disable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # False - - """ - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_unpatch() - - -def is_patched(): - """ - Returns whether Numpy has been patched with mkl_umath. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # True - - >>> mkl_umath.restore() # Disable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # False - - """ - if not _is_tls_initialized(): - _initialize_tls() - return _tls.patch.is_patched() - - -from contextlib import ContextDecorator - - -class mkl_umath(ContextDecorator): - """ - Context manager and decorator to temporarily patch NumPy ufuncs - with MKL-based implementations. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> with mkl_umath.mkl_umath(): # Enable mkl_umath in Numpy - >>> print(mkl_umath.is_patched()) - # True - - >>> mkl_umath.is_patched() - # False - - """ - def __enter__(self): - use_in_numpy() - return self - - def __exit__(self, *exc): - restore() - return False diff --git a/mkl_umath/src/_patch_numpy.pyx b/mkl_umath/src/_patch_numpy.pyx new file mode 100644 index 00000000..cbd5ed6d --- /dev/null +++ b/mkl_umath/src/_patch_numpy.pyx @@ -0,0 +1,369 @@ +# Copyright (c) 2019, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# distutils: language = c +# cython: language_level=3 + +import warnings +from contextlib import ContextDecorator +from threading import Lock, local + +import mkl_umath._ufuncs as mu + +cimport numpy as cnp + +import numpy as np + +from libc.stdlib cimport free, malloc + +cnp.import_umath() + +ctypedef struct function_info: + cnp.PyUFuncGenericFunction original_function + cnp.PyUFuncGenericFunction patch_function + int* signature + + +cdef class _patch_impl: + cdef int functions_count + cdef function_info* functions + + functions_dict = dict() + + def __cinit__(self): + cdef int pi, oi + + umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)] + self.functions_count = 0 + for umath in umaths: + mkl_umath_func = getattr(mu, umath) + self.functions_count += mkl_umath_func.ntypes + + self.functions = malloc( + self.functions_count * sizeof(function_info) + ) + + func_number = 0 + for umath in umaths: + patch_umath = getattr(mu, umath) + c_patch_umath = patch_umath + c_orig_umath = getattr(np, umath) + nargs = c_patch_umath.nargs + for pi in range(c_patch_umath.ntypes): + oi = 0 + while oi < c_orig_umath.ntypes: + found = True + for i in range(c_patch_umath.nargs): + if ( + c_patch_umath.types[pi * nargs + i] + != c_orig_umath.types[oi * nargs + i] + ): + found = False + break + if found is True: + break + oi = oi + 1 + if oi < c_orig_umath.ntypes: + self.functions[func_number].original_function = ( + c_orig_umath.functions[oi] + ) + self.functions[func_number].patch_function = ( + c_patch_umath.functions[pi] + ) + self.functions[func_number].signature = ( + malloc(nargs * sizeof(int)) + ) + for i in range(nargs): + self.functions[func_number].signature[i] = ( + c_patch_umath.types[pi * nargs + i] + ) + self.functions_dict[(umath, patch_umath.types[pi])] = ( + func_number + ) + func_number = func_number + 1 + else: + raise RuntimeError( + f"Unable to find original function for: {umath} " + f"{patch_umath.types[pi]}" + ) + + def __dealloc__(self): + for i in range(self.functions_count): + free(self.functions[i].signature) + free(self.functions) + + cdef int _replace_loop( + self, + object func, + cnp.PyUFuncGenericFunction function, + ) except -1: + cdef int res + cdef cnp.PyUFuncGenericFunction temp + cdef int* signature + + np_umath = getattr(np, func[0]) + index = self.functions_dict[func] + signature = self.functions[index].signature + res = cnp.PyUFunc_ReplaceLoopBySignature( + np_umath, function, signature, &temp + ) + return res + + def do_patch(self): + cdef int index + + for func in self.functions_dict: + index = self.functions_dict[func] + if self._replace_loop( + func, self.functions[index].patch_function + ) != 0: + raise RuntimeError( + f"Failed to patch {func[0]} with signature {func[1]}. " + "NumPy may be partially restored or in an invalid state." + ) + + def do_unpatch(self): + cdef int index + + for func in self.functions_dict: + index = self.functions_dict[func] + if self._replace_loop( + func, self.functions[index].original_function + ) != 0: + raise RuntimeError( + f"Failed to restore {func[0]} with signature {func[1]}. " + "NumPy may be partially restored or in an invalid state." + ) + + +class _GlobalPatch: + def __init__(self): + self._lock = Lock() + self._patch_count = 0 + self._tls = local() + self._patcher = None + + def do_patch(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if self._patch_count == 0: + if verbose: + print( + "Now patching NumPy ufuncs with mkl_umath loops." + ) + print( + "Please direct bug reports to " + "https://github.com/IntelPython/mkl_umath" + ) + if self._patcher is None: + # lazy initialization of the patcher to save memory + self._patcher = _patch_impl() + self._patcher.do_patch() + + self._patch_count += 1 + self._tls.local_count = local_count + 1 + + def do_restore(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if local_count <= 0: + if verbose: + print( + "Warning: restore_numpy_umath called more times than " + "patch_numpy_umath in this thread." + ) + return + + next_patch_count = self._patch_count - 1 + if next_patch_count == 0: + if verbose: + print("Now restoring original NumPy loops.") + self._patcher.do_unpatch() + + self._tls.local_count -= 1 + self._patch_count = next_patch_count + + def is_patched(self): + with self._lock: + return self._patch_count > 0 + + +_patch = _GlobalPatch() + + +def patch_numpy_umath(verbose=False): + """ + Patch NumPy's ufuncs with mkl_umath's loops. + + Parameters + ---------- + verbose : bool, optional + print message when starting the patching process. + + Notes + ----- + This function uses reference-counted semantics. Each call increments a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_umath` and `restore_numpy_umath`. + + ⚠️ Warning + ------------------------- + If used in a multi-threaded program, ALL concurrent threads executing NumPy + operations must either have applied the patch prior to execution, or run + entirely within the `mkl_umath` context manager. Executing standard NumPy + calls in one thread while another thread is actively patching or unpatching + will lead to undefined behavior at best, and segmentation faults at worst. + For this reason, it is recommended to prefer the `mkl_umath` context + manager. + + Examples + -------- + >>> import mkl_umath + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False + """ + _patch.do_patch(verbose=verbose) + + +def restore_numpy_umath(verbose=False): + """ + Restore NumPy's ufuncs to the original loops. + + Parameters + ---------- + verbose : bool, optional + print message when starting restoration process. + + Notes + ----- + This function uses reference-counted semantics. Each call decrements a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_umath` and `restore_numpy_umath`. + + ⚠️ Warning + ------------------------- + If used in a multi-threaded program, ALL concurrent threads executing NumPy + operations must either have applied the patch prior to execution, or run + entirely within the `mkl_umath` context manager. Executing standard NumPy + calls in one thread while another thread is actively patching or unpatching + will lead to undefined behavior at best, and segmentation faults at worst. + For this reason, it is recommended to prefer the `mkl_umath` context + manager. + + Examples + -------- + >>> import mkl_umath + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False + """ + _patch.do_restore(verbose=verbose) + + +def use_in_numpy(): + """ + Deprecated alias for patch_numpy_umath. + + See patch_numpy_umath for details and examples. + """ + warnings.warn( + "use_in_numpy is deprecated since mkl_random 0.4.0 and will be removed " + "in a future release. Use `patch_numpy_umath` instead.", + DeprecationWarning, + stacklevel=2, + ) + patch_numpy_umath() + + +def restore(): + """ + Deprecated alias for restore_numpy_umath. + + See restore_numpy_umath for details and examples. + """ + warnings.warn( + "restore is deprecated since mkl_random 0.4.0 and will be " + "removed in a future release. Use `restore_numpy_umath` instead.", + DeprecationWarning, + stacklevel=2, + ) + restore_numpy_umath() + + +def is_patched(): + """Return True if NumPy umath loops have been patched by mkl_umath.""" + return _patch.is_patched() + + +class mkl_umath(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy ufuncs + with MKL-based implementations. + + ⚠️ Warning + ------------------------- + If used in a multi-threaded program, ALL concurrent threads executing NumPy + operations must either have applied the patch prior to execution, or run + entirely within the `mkl_umath` context manager. Executing standard NumPy + calls in one thread while another thread is actively patching or unpatching + will lead to undefined behavior at best, and segmentation faults at worst. + For this reason, it is recommended to prefer the `mkl_umath` context + manager. + + Examples + -------- + >>> import mkl_umath + >>> mkl_umath.is_patched() + # False + + >>> with mkl_umath.mkl_umath(): # Enable mkl_umath in Numpy + >>> print(mkl_umath.is_patched()) + # True + + >>> mkl_umath.is_patched() + # False + """ + def __enter__(self): + patch_numpy_umath() + return self + + def __exit__(self, *exc): + restore_numpy_umath() + return False diff --git a/mkl_umath/tests/AGENTS.md b/mkl_umath/tests/AGENTS.md index 74571825..f962fafc 100644 --- a/mkl_umath/tests/AGENTS.md +++ b/mkl_umath/tests/AGENTS.md @@ -7,7 +7,7 @@ Unit tests for MKL-backed ufuncs and NumPy patching. ## Test coverage - Ufunc correctness: compare MKL loops vs NumPy reference -- Patching: `use_in_numpy()`, `restore()`, `is_patched()` state transitions +- Patching: `patch_numpy_umath()`, `restore_numpy_umath()`, `is_patched()` state transitions - Edge cases: NaN, Inf, empty arrays, large arrays - Dtype coverage: float32, float64, complex64, complex128 diff --git a/mkl_umath/tests/__init__.py b/mkl_umath/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mkl_umath/tests/test_basic.py b/mkl_umath/tests/test_basic.py index 3fd5a35c..116a25ff 100644 --- a/mkl_umath/tests/test_basic.py +++ b/mkl_umath/tests/test_basic.py @@ -26,7 +26,6 @@ import numpy as np import pytest -import mkl_umath._patch as mp # pylint: disable=no-name-in-module import mkl_umath._ufuncs as mu # pylint: disable=no-name-in-module np.random.seed(42) @@ -63,8 +62,8 @@ def get_args(args_str, size, low, high): mkl_cases = {} fall_back_cases = {} for umath in umaths: - mkl_umath = getattr(mu, umath) - types = mkl_umath.types + _mkl_umath = getattr(mu, umath) + types = _mkl_umath.types size_mkl = 8192 + 1 for type_ in types: args_str = type_[: type_.find("->")] @@ -102,10 +101,10 @@ def get_id(val): def test_mkl_umath(case): umath, _ = case args = test_mkl[case] - mkl_umath = getattr(mu, umath) + _mkl_umath = getattr(mu, umath) np_umath = getattr(np, umath) - mkl_res = mkl_umath(*args) + mkl_res = _mkl_umath(*args) np_res = np_umath(*args) assert np.allclose(mkl_res, np_res), f"Results for '{umath}' do not match" @@ -115,10 +114,10 @@ def test_mkl_umath(case): def test_fall_back_umath(case): umath, _ = case args = test_fall_back[case] - mkl_umath = getattr(mu, umath) + _mkl_umath = getattr(mu, umath) np_umath = getattr(np, umath) - mkl_res = mkl_umath(*args) + mkl_res = _mkl_umath(*args) np_res = np_umath(*args) assert np.allclose(mkl_res, np_res), f"Results for '{umath}' do not match" @@ -190,14 +189,3 @@ def test_reduce_complex(func, dtype): assert np.allclose( mkl_res, np_res ), f"Results for '{func}[reduce]' do not match" - - -def test_patch(): - mp.restore() - assert not mp.is_patched() - - mp.use_in_numpy() # Enable mkl_umath in Numpy - assert mp.is_patched() - - mp.restore() # Disable mkl_umath in Numpy - assert not mp.is_patched() diff --git a/mkl_umath/tests/test_patching.py b/mkl_umath/tests/test_patching.py new file mode 100644 index 00000000..3fa58fdd --- /dev/null +++ b/mkl_umath/tests/test_patching.py @@ -0,0 +1,102 @@ +# Copyright (c) 2019, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import mkl_umath + +import contextlib +import sys + +from dataclasses import dataclass +from io import StringIO + +import pytest + + +@dataclass +class CapturedOutput(): + stdout: str + + +@contextlib.contextmanager +def capture_output(): + old_stdout = sys.stdout + capturer = StringIO() + sys.stdout = capturer + output = CapturedOutput(stdout="") + yield output + sys.stdout = old_stdout + output.stdout = capturer.getvalue() + + +def test_patch_basic(): + mkl_umath.restore_numpy_umath() + assert not mkl_umath.is_patched() + + mkl_umath.patch_numpy_umath() # Enable mkl_umath in Numpy + assert mkl_umath.is_patched() + + mkl_umath.restore_numpy_umath() # Disable mkl_umath in Numpy + assert not mkl_umath.is_patched() + + +def test_patch_redundant_patching(): + assert not mkl_umath.is_patched() + + mkl_umath.patch_numpy_umath() + mkl_umath.patch_numpy_umath() + + assert mkl_umath.is_patched() + + mkl_umath.restore_numpy_umath() + assert mkl_umath.is_patched() + + mkl_umath.restore_numpy_umath() + assert not mkl_umath.is_patched() + + +def test_patch_reentrant(): + assert not mkl_umath.is_patched() + + with mkl_umath.mkl_umath(): + assert mkl_umath.is_patched() + + with mkl_umath.mkl_umath(): + assert mkl_umath.is_patched() + + assert mkl_umath.is_patched() + + assert not mkl_umath.is_patched() + + +def test_patch_verbose(): + assert not mkl_umath.is_patched() + + with capture_output() as output: + mkl_umath.patch_numpy_umath(verbose=True) + assert output.stdout + assert mkl_umath.is_patched() + + mkl_umath.restore_numpy_umath() + assert not mkl_umath.is_patched()