From c07eef0d891e8e41cfa57260595c9731a33390c1 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 9 Mar 2026 21:19:49 -0700 Subject: [PATCH 1/4] Update umath patching * Add safety checks to patching * Add warnings related to multi-threaded programs * Implement _GlobalPatch wrapper class which implements patching globally * Rename _patch module to _patch_numpy * Rename patching functions --- .github/workflows/conda-package.yml | 4 +- AGENTS.md | 8 +- CMakeLists.txt | 16 +- conda-recipe-cf/meta.yaml | 2 +- conda-recipe/meta.yaml | 2 +- mkl_umath/AGENTS.md | 8 +- mkl_umath/__init__.py | 9 +- mkl_umath/src/AGENTS.md | 9 +- .../src/{_patch.pyx => _patch_numpy.pyx} | 224 ++++++++++-------- mkl_umath/tests/AGENTS.md | 2 +- mkl_umath/tests/test_basic.py | 26 +- 11 files changed, 174 insertions(+), 136 deletions(-) rename mkl_umath/src/{_patch.pyx => _patch_numpy.pyx} (54%) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 20b6c3ce..914ba5d6 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -133,7 +133,7 @@ jobs: run: | source "$CONDA/etc/profile.d/conda.sh" conda activate test_mkl_umath - python -c "import mkl_umath, numpy as np; mkl_umath.use_in_numpy(); np.sin(np.linspace(0, 1, num=10**6));" + python -c "import mkl_umath, numpy as np; mkl_umath.patch_numpy_umath(); np.sin(np.linspace(0, 1, num=10**6));" - name: Run tests run: | @@ -328,7 +328,7 @@ jobs: run: | @ECHO ON conda activate mkl_umath_test - python -c "import mkl_umath, numpy as np; mkl_umath.use_in_numpy(); np.sin(np.linspace(0, 1, num=10**6));" + python -c "import mkl_umath, numpy as np; mkl_umath.patch_numpy_umath(); np.sin(np.linspace(0, 1, num=10**6));" - name: Run tests shell: cmd /C CALL {0} diff --git a/AGENTS.md b/AGENTS.md index 71aaa95e..e87f52a7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,13 +7,13 @@ Entry point for agent context in this repo. It provides: - `mkl_umath._ufuncs` — OneMKL-backed NumPy ufunc loops -- `mkl_umath._patch` — runtime patching interface (`use_in_numpy()`, `restore()`, `is_patched()`) +- `mkl_umath._patch_numpy` — runtime patching interface (`patch_numpy_umath` `restore_numpy_umath`, `is_patched()`) - Performance-optimized math operations (sin, cos, exp, log, etc.) using Intel MKL VM ## Key components - **Python interface:** `mkl_umath/__init__.py`, `_init_helper.py` - **Core C implementation:** `mkl_umath/src/` (ufuncsmodule.c, mkl_umath_loops.c.src) -- **Cython patch layer:** `mkl_umath/src/_patch.pyx` +- **Cython patch layer:** `mkl_umath/src/_patch_numpy.pyx` - **Code generation:** `generate_umath.py`, `generate_umath_doc.py` - **Build system:** CMake (CMakeLists.txt) + scikit-build @@ -50,9 +50,9 @@ CC=${CC:-icx} pip install --no-build-isolation --no-deps . # clang is also supp ## Usage ```python import mkl_umath -mkl_umath.use_in_numpy() # Patch NumPy to use MKL loops +mkl_umath.patch_numpy_umath() # Patch NumPy to use MKL loops # ... perform NumPy operations (now accelerated) ... -mkl_umath.restore() # Restore original NumPy loops +mkl_umath.restore_numpy_umath() # Restore original NumPy loops ``` ## How to work in this repo diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a00e86f..38eb6fb7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,13 +135,13 @@ if (UNIX) endif() install(TARGETS _ufuncs LIBRARY DESTINATION mkl_umath) -add_cython_target(_patch "mkl_umath/src/_patch.pyx" C OUTPUT_VAR _generated_src) -Python_add_library(_patch MODULE WITH_SOABI ${_generated_src}) -target_include_directories(_patch PRIVATE "mkl_umath/src/" ${Python_NumPy_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}) -target_compile_definitions(_patch PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION) -target_link_libraries(_patch PRIVATE mkl_umath_loops) -set_target_properties(_patch PROPERTIES C_STANDARD 99) +add_cython_target(_patch_numpy "mkl_umath/src/_patch_numpy.pyx" C OUTPUT_VAR _generated_src) +Python_add_library(_patch_numpy MODULE WITH_SOABI ${_generated_src}) +target_include_directories(_patch_numpy PRIVATE "mkl_umath/src/" ${Python_NumPy_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}) +target_compile_definitions(_patch_numpy PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION) +target_link_libraries(_patch_numpy PRIVATE mkl_umath_loops) +set_target_properties(_patch_numpy PROPERTIES C_STANDARD 99) if (UNIX) - set_target_properties(_patch PROPERTIES INSTALL_RPATH "$ORIGIN/../..;$ORIGIN/../../..;$ORIGIN") + set_target_properties(_patch_numpy PROPERTIES INSTALL_RPATH "$ORIGIN/../..;$ORIGIN/../../..;$ORIGIN") endif() -install(TARGETS _patch LIBRARY DESTINATION mkl_umath) +install(TARGETS _patch_numpy LIBRARY DESTINATION mkl_umath) diff --git a/conda-recipe-cf/meta.yaml b/conda-recipe-cf/meta.yaml index 92d5a5db..f678f9af 100644 --- a/conda-recipe-cf/meta.yaml +++ b/conda-recipe-cf/meta.yaml @@ -44,7 +44,7 @@ test: imports: - mkl_umath - mkl_umath._ufuncs - - mkl_umath._patch + - mkl_umath._patch_numpy about: home: http://github.com/IntelPython/mkl_umath diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 38ea112f..4b15337c 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -48,7 +48,7 @@ test: imports: - mkl_umath - mkl_umath._ufuncs - - mkl_umath._patch + - mkl_umath._patch_numpy about: home: http://github.com/IntelPython/mkl_umath diff --git a/mkl_umath/AGENTS.md b/mkl_umath/AGENTS.md index 06b55865..f3c38bed 100644 --- a/mkl_umath/AGENTS.md +++ b/mkl_umath/AGENTS.md @@ -14,9 +14,9 @@ Core MKL-backed ufunc implementation: Python interface, Cython patching, and C/M ## Patching API ```python -mkl_umath.use_in_numpy() # Replace NumPy loops with MKL -mkl_umath.restore() # Restore original NumPy loops -mkl_umath.is_patched() # Check patch status +mkl_umath.patch_numpy_umath() # Replace NumPy loops with MKL +mkl_umath.restore_numpy_umath() # Restore original NumPy loops +mkl_umath.is_patched() # Check patch status ``` ## Development guardrails @@ -31,6 +31,6 @@ mkl_umath.is_patched() # Check patch status - Docstrings: dual NumPy 1.x/2.x support via separate docstring modules ## Notes -- `_patch.pyx` is Cython; changes require Cython rebuild +- `_patch_numpy.pyx` is Cython; changes require Cython rebuild - MKL VM loops in `src/mkl_umath_loops.c.src` - `src/ufuncsmodule.c` — NumPy ufunc registration and dispatch diff --git a/mkl_umath/__init__.py b/mkl_umath/__init__.py index 0477d5a2..148bbc5b 100644 --- a/mkl_umath/__init__.py +++ b/mkl_umath/__init__.py @@ -29,10 +29,13 @@ """ from . import _init_helper -from ._patch import is_patched, mkl_umath, restore, use_in_numpy +from ._patch_numpy import ( + is_patched, + mkl_umath, + patch_numpy_umath, + restore_numpy_umath, +) from ._ufuncs import * from ._version import __version__ -# TODO: add __all__ with public API and remove star imports - del _init_helper diff --git a/mkl_umath/src/AGENTS.md b/mkl_umath/src/AGENTS.md index 83d978e7..9ff52362 100644 --- a/mkl_umath/src/AGENTS.md +++ b/mkl_umath/src/AGENTS.md @@ -7,7 +7,7 @@ C/Cython implementation layer: MKL VM integration, ufunc loops, and NumPy patchi - **ufuncsmodule.h** — ufunc module public headers - **mkl_umath_loops.c.src** — MKL VM loop implementations (template, ~60k LOC) - **mkl_umath_loops.h.src** — loop function declarations (template) -- **_patch.pyx** — Cython patching layer (runtime NumPy loop replacement) +- **_patch_numpy.pyx** — Cython patching layer (runtime NumPy loop replacement) - **fast_loop_macros.h** — loop generation macros - **blocking_utils.h** — blocking/chunking utilities for large arrays @@ -21,15 +21,16 @@ C/Cython implementation layer: MKL VM integration, ufunc loops, and NumPy patchi - Blocking strategy: chunk large arrays for cache efficiency - Error handling: MKL VM status → NumPy error state -## Patching mechanism (_patch.pyx) -- Cython extension exposing `use_in_numpy()`, `restore()`, `is_patched()` +## Patching mechanism (_patch_numpy.pyx) +- Cython extension exposing `patch_numpy_umath()`, `restore_numpy_umath()`, + `is_patched()` - Replaces function pointers in NumPy's ufunc loop tables - Thread-safe: guards against concurrent patching - Reversible: stores original pointers for restoration ## Build output - `mkl_umath_loops.c` → shared library (libmkl_umath_loops.so/.dll) -- `_patch.pyx` → Python extension (_patch.*.so) +- `_patch_numpy.pyx` → Python extension (_patch.*.so) - `ufuncsmodule.c` + `__umath_generated.c` → `_ufuncs` extension ## Development notes diff --git a/mkl_umath/src/_patch.pyx b/mkl_umath/src/_patch_numpy.pyx similarity index 54% rename from mkl_umath/src/_patch.pyx rename to mkl_umath/src/_patch_numpy.pyx index 8d2d299d..6cd2990e 100644 --- a/mkl_umath/src/_patch.pyx +++ b/mkl_umath/src/_patch_numpy.pyx @@ -26,6 +26,9 @@ # distutils: language = c # cython: language_level=3 +from contextlib import ContextDecorator +from threading import Lock, local + import mkl_umath._ufuncs as mu cimport numpy as cnp @@ -36,30 +39,26 @@ from libc.stdlib cimport free, malloc cnp.import_umath() - ctypedef struct function_info: cnp.PyUFuncGenericFunction original_function cnp.PyUFuncGenericFunction patch_function int* signature -cdef class patch: +cdef class _patch_impl: cdef int functions_count cdef function_info* functions - cdef bint _is_patched functions_dict = dict() def __cinit__(self): cdef int pi, oi - self._is_patched = False - umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)] self.functions_count = 0 for umath in umaths: - mkl_umath = getattr(mu, umath) - self.functions_count += mkl_umath.ntypes + mkl_umath_func = getattr(mu, umath) + self.functions_count += mkl_umath_func.ntypes self.functions = malloc( self.functions_count * sizeof(function_info) @@ -115,7 +114,7 @@ cdef class patch: free(self.functions) def do_patch(self): - cdef int _res + cdef int res cdef cnp.PyUFuncGenericFunction temp cdef cnp.PyUFuncGenericFunction function cdef int* signature @@ -126,14 +125,16 @@ cdef class patch: function = self.functions[index].patch_function signature = self.functions[index].signature # TODO: check res, 0 means success, -1 means error - _res = cnp.PyUFunc_ReplaceLoopBySignature( + res = cnp.PyUFunc_ReplaceLoopBySignature( np_umath, function, signature, &temp ) - - self._is_patched = True + if res != 0: + raise RuntimeError( + f"Failed to patch {func[0]} with signature {func[1]}" + ) def do_unpatch(self): - cdef int _res + cdef int res cdef cnp.PyUFuncGenericFunction temp cdef cnp.PyUFuncGenericFunction function cdef int* signature @@ -143,103 +144,127 @@ cdef class patch: index = self.functions_dict[func] function = self.functions[index].original_function signature = self.functions[index].signature - # TODO: check res, 0 means success, -1 means error - _res = cnp.PyUFunc_ReplaceLoopBySignature( + res = cnp.PyUFunc_ReplaceLoopBySignature( np_umath, function, signature, &temp ) - - self._is_patched = False + if res != 0: + raise RuntimeError( + f"Failed to restore {func[0]} with signature {func[1]}" + ) + + +class _GlobalPatch: + def __init__(self): + self._lock = Lock() + self._patch_count = 0 + self._tls = local() + self._patcher = None + + def do_patch(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if self._patch_count == 0: + if verbose: + print( + "Now patching NumPy FFT submodule with mkl_fft NumPy " + "interface." + ) + print( + "Please direct bug reports to " + "https://github.com/IntelPython/mkl_fft" + ) + if self._patcher is None: + # lazy initialization of the patcher to save memory + self._patcher = _patch_impl() + self._patcher.do_patch() + + self._patch_count += 1 + self._tls.local_count = local_count + 1 + + def do_restore(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if local_count <= 0: + if verbose: + print( + "Warning: restore_numpy_umath called more times than " + "patch_numpy_fft in this thread." + ) + return + self._tls.local_count -= 1 + self._patch_count -= 1 + if self._patch_count == 0: + if verbose: + print("Now restoring original NumPy loops.") + self._patcher.do_unpatch() def is_patched(self): - return self._is_patched - -from threading import local as threading_local - -_tls = threading_local() + with self._lock: + return self._patch_count > 0 -def _is_tls_initialized(): - return getattr(_tls, "initialized", False) +_patch = _GlobalPatch() -def _initialize_tls(): - _tls.patch = patch() - _tls.initialized = True - - -def use_in_numpy(): +def patch_numpy_umath(verbose=False): """ - Enables using of mkl_umath in Numpy. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # True - - >>> mkl_umath.restore() # Disable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # False - + Patch NumPy's ufuncs with mkl_umath's loops. + + Parameters + ---------- + verbose : bool, optional + print message when starting the patching process. + + Notes + ----- + This function uses reference-counted semantics. Each call increments a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_umath` and `restore_numpy_umath`. + + ⚠️ Warning + ------------------------- + If used in a multi-threaded program, ALL concurrent threads executing NumPy + operations must either have applied the patch prior to execution, or run + entirely within the `mkl_umath` context manager. Executing standard NumPy + calls in one thread while another thread is actively patching or unpatching + will lead to undefined behavior at best, and segmentation faults at worst. + For this reason, it is recommended to prefer the `mkl_umath` context + manager. """ - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_patch() + _patch.do_patch(verbose=verbose) -def restore(): +def restore_numpy_umath(verbose=False): """ - Disables using of mkl_umath in Numpy. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # True - - >>> mkl_umath.restore() # Disable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # False - + Restore NumPy's ufuncs to the original loops. + + Parameters + ---------- + verbose : bool, optional + print message when starting restoration process. + + Notes + ----- + This function uses reference-counted semantics. Each call decrements a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_umath` and `restore_numpy_umath`. + + ⚠️ Warning + ------------------------- + If used in a multi-threaded program, ALL concurrent threads executing NumPy + operations must either have applied the patch prior to execution, or run + entirely within the `mkl_umath` context manager. Executing standard NumPy + calls in one thread while another thread is actively patching or unpatching + will lead to undefined behavior at best, and segmentation faults at worst. + For this reason, it is recommended to prefer the `mkl_umath` context + manager. """ - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_unpatch() + _patch.do_restore(verbose=verbose) def is_patched(): - """ - Returns whether Numpy has been patched with mkl_umath. - - Examples - -------- - >>> import mkl_umath - >>> mkl_umath.is_patched() - # False - - >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # True - - >>> mkl_umath.restore() # Disable mkl_umath in Numpy - >>> mkl_umath.is_patched() - # False - - """ - if not _is_tls_initialized(): - _initialize_tls() - return _tls.patch.is_patched() - - -from contextlib import ContextDecorator + """Return True if NumPy umath loops have been patched by mkl_umath.""" + return _patch.is_patched() class mkl_umath(ContextDecorator): @@ -247,6 +272,16 @@ class mkl_umath(ContextDecorator): Context manager and decorator to temporarily patch NumPy ufuncs with MKL-based implementations. + ⚠️ Warning + ------------------------- + If used in a multi-threaded program, ALL concurrent threads executing NumPy + operations must either have applied the patch prior to execution, or run + entirely within the `mkl_umath` context manager. Executing standard NumPy + calls in one thread while another thread is actively patching or unpatching + will lead to undefined behavior at best, and segmentation faults at worst. + For this reason, it is recommended to prefer the `mkl_umath` context + manager. + Examples -------- >>> import mkl_umath @@ -259,12 +294,11 @@ class mkl_umath(ContextDecorator): >>> mkl_umath.is_patched() # False - """ def __enter__(self): - use_in_numpy() + patch_numpy_umath() return self def __exit__(self, *exc): - restore() + restore_numpy_umath() return False diff --git a/mkl_umath/tests/AGENTS.md b/mkl_umath/tests/AGENTS.md index 74571825..f962fafc 100644 --- a/mkl_umath/tests/AGENTS.md +++ b/mkl_umath/tests/AGENTS.md @@ -7,7 +7,7 @@ Unit tests for MKL-backed ufuncs and NumPy patching. ## Test coverage - Ufunc correctness: compare MKL loops vs NumPy reference -- Patching: `use_in_numpy()`, `restore()`, `is_patched()` state transitions +- Patching: `patch_numpy_umath()`, `restore_numpy_umath()`, `is_patched()` state transitions - Edge cases: NaN, Inf, empty arrays, large arrays - Dtype coverage: float32, float64, complex64, complex128 diff --git a/mkl_umath/tests/test_basic.py b/mkl_umath/tests/test_basic.py index 3fd5a35c..ff53c146 100644 --- a/mkl_umath/tests/test_basic.py +++ b/mkl_umath/tests/test_basic.py @@ -26,7 +26,7 @@ import numpy as np import pytest -import mkl_umath._patch as mp # pylint: disable=no-name-in-module +import mkl_umath import mkl_umath._ufuncs as mu # pylint: disable=no-name-in-module np.random.seed(42) @@ -63,8 +63,8 @@ def get_args(args_str, size, low, high): mkl_cases = {} fall_back_cases = {} for umath in umaths: - mkl_umath = getattr(mu, umath) - types = mkl_umath.types + _mkl_umath = getattr(mu, umath) + types = _mkl_umath.types size_mkl = 8192 + 1 for type_ in types: args_str = type_[: type_.find("->")] @@ -102,10 +102,10 @@ def get_id(val): def test_mkl_umath(case): umath, _ = case args = test_mkl[case] - mkl_umath = getattr(mu, umath) + _mkl_umath = getattr(mu, umath) np_umath = getattr(np, umath) - mkl_res = mkl_umath(*args) + mkl_res = _mkl_umath(*args) np_res = np_umath(*args) assert np.allclose(mkl_res, np_res), f"Results for '{umath}' do not match" @@ -115,10 +115,10 @@ def test_mkl_umath(case): def test_fall_back_umath(case): umath, _ = case args = test_fall_back[case] - mkl_umath = getattr(mu, umath) + _mkl_umath = getattr(mu, umath) np_umath = getattr(np, umath) - mkl_res = mkl_umath(*args) + mkl_res = _mkl_umath(*args) np_res = np_umath(*args) assert np.allclose(mkl_res, np_res), f"Results for '{umath}' do not match" @@ -193,11 +193,11 @@ def test_reduce_complex(func, dtype): def test_patch(): - mp.restore() - assert not mp.is_patched() + mkl_umath.restore_numpy_umath() + assert not mkl_umath.is_patched() - mp.use_in_numpy() # Enable mkl_umath in Numpy - assert mp.is_patched() + mkl_umath.patch_numpy_umath() # Enable mkl_umath in Numpy + assert mkl_umath.is_patched() - mp.restore() # Disable mkl_umath in Numpy - assert not mp.is_patched() + mkl_umath.restore_numpy_umath() # Disable mkl_umath in Numpy + assert not mkl_umath.is_patched() From 76c0e70a9c5e54625202a1259fc0305b98f659f2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Mar 2026 08:47:49 -0700 Subject: [PATCH 2/4] address review comments --- mkl_umath/__init__.py | 2 + mkl_umath/src/_patch_numpy.pyx | 68 +++++++++++++++++++++++++++++++--- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/mkl_umath/__init__.py b/mkl_umath/__init__.py index 148bbc5b..bdff1c97 100644 --- a/mkl_umath/__init__.py +++ b/mkl_umath/__init__.py @@ -38,4 +38,6 @@ from ._ufuncs import * from ._version import __version__ +# TODO: add __all__ with public API and remove star imports + del _init_helper diff --git a/mkl_umath/src/_patch_numpy.pyx b/mkl_umath/src/_patch_numpy.pyx index 6cd2990e..816218c3 100644 --- a/mkl_umath/src/_patch_numpy.pyx +++ b/mkl_umath/src/_patch_numpy.pyx @@ -26,6 +26,8 @@ # distutils: language = c # cython: language_level=3 +import warnings + from contextlib import ContextDecorator from threading import Lock, local @@ -124,7 +126,6 @@ cdef class _patch_impl: index = self.functions_dict[func] function = self.functions[index].patch_function signature = self.functions[index].signature - # TODO: check res, 0 means success, -1 means error res = cnp.PyUFunc_ReplaceLoopBySignature( np_umath, function, signature, &temp ) @@ -166,12 +167,11 @@ class _GlobalPatch: if self._patch_count == 0: if verbose: print( - "Now patching NumPy FFT submodule with mkl_fft NumPy " - "interface." + "Now patching NumPy ufuncs with mkl_umath loops." ) print( "Please direct bug reports to " - "https://github.com/IntelPython/mkl_fft" + "https://github.com/IntelPython/mkl_umath" ) if self._patcher is None: # lazy initialization of the patcher to save memory @@ -188,7 +188,7 @@ class _GlobalPatch: if verbose: print( "Warning: restore_numpy_umath called more times than " - "patch_numpy_fft in this thread." + "patch_numpy_umath in this thread." ) return self._tls.local_count -= 1 @@ -230,6 +230,20 @@ def patch_numpy_umath(verbose=False): will lead to undefined behavior at best, and segmentation faults at worst. For this reason, it is recommended to prefer the `mkl_umath` context manager. + + Examples + -------- + >>> import mkl_umath + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False """ _patch.do_patch(verbose=verbose) @@ -258,10 +272,54 @@ def restore_numpy_umath(verbose=False): will lead to undefined behavior at best, and segmentation faults at worst. For this reason, it is recommended to prefer the `mkl_umath` context manager. + + Examples + -------- + >>> import mkl_umath + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False """ _patch.do_restore(verbose=verbose) +def use_in_numpy(): + """ + Deprecated alias for patch_numpy_umath. + + See patch_numpy_umath for details and examples. + """ + warnings.warn( + "use_in_numpy is deprecated since mkl_random 0.4.0 and will be removed " + "in a future release. Use `patch_numpy_umath` instead.", + DeprecationWarning, + stacklevel=2, + ) + patch_numpy_umath() + + +def restore(): + """ + Deprecated alias for restore_numpy_umath. + + See restore_numpy_umath for details and examples. + """ + warnings.warn( + "restore is deprecated since mkl_random 0.4.0 and will be " + "removed in a future release. Use `restore_numpy_umath` instead.", + DeprecationWarning, + stacklevel=2, + ) + restore_numpy_umath() + + def is_patched(): """Return True if NumPy umath loops have been patched by mkl_umath.""" return _patch.is_patched() From 910011a3ebec64815f743dc788574a3e245139e0 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Mar 2026 08:59:24 -0700 Subject: [PATCH 3/4] add deprecated patching aliases to __init__.py --- mkl_umath/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mkl_umath/__init__.py b/mkl_umath/__init__.py index bdff1c97..d0603992 100644 --- a/mkl_umath/__init__.py +++ b/mkl_umath/__init__.py @@ -33,7 +33,9 @@ is_patched, mkl_umath, patch_numpy_umath, + restore, restore_numpy_umath, + use_in_numpy, ) from ._ufuncs import * from ._version import __version__ From 711df0e0c16c1f6b3b7a34fdcb5a988ae795227f Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Mar 2026 09:01:50 -0700 Subject: [PATCH 4/4] pre-commit fix --- mkl_umath/src/_patch_numpy.pyx | 1 - 1 file changed, 1 deletion(-) diff --git a/mkl_umath/src/_patch_numpy.pyx b/mkl_umath/src/_patch_numpy.pyx index 816218c3..7ef82f9f 100644 --- a/mkl_umath/src/_patch_numpy.pyx +++ b/mkl_umath/src/_patch_numpy.pyx @@ -27,7 +27,6 @@ # cython: language_level=3 import warnings - from contextlib import ContextDecorator from threading import Lock, local