对数组 API 标准的支持#

注意

数组 API 标准支持仍处于实验阶段,并通过环境变量隐藏。目前只有一小部分公共 API 得到支持。

本指南描述了如何使用添加对Python 数组 API 标准的支持。该标准允许用户开箱即用地将任何与数组 API 兼容的数组库与 SciPy 的部分功能一起使用。

The RFC 定义了 SciPy 如何实现对该标准的支持,其主要原则是“输入数组类型等于输出数组类型”。此外,该实现对允许的类数组输入执行更严格的验证,例如拒绝 numpy matrix 和 masked array 实例,以及具有 object dtype 的数组。

在下文中,与数组 API 兼容的命名空间表示为 xp

使用数组 API 标准支持#

要启用数组 API 标准支持,必须在导入 SciPy 之前设置一个环境变量

export SCIPY_ARRAY_API=1

这既启用了数组 API 标准支持,也为类数组参数提供了更严格的输入验证。请注意,此环境变量旨在作为临时措施,用于进行增量更改并将其合并到 main 中,而不会立即影响向后兼容性。我们不打算长期保留此环境变量。

此聚类示例展示了将 PyTorch 张量作为输入和返回值的使用方法

>>> import torch
>>> from scipy.cluster.vq import vq
>>> code_book = torch.tensor([[1., 1., 1.],
...                           [2., 2., 2.]])
>>> features  = torch.tensor([[1.9, 2.3, 1.7],
...                           [1.5, 2.5, 2.2],
...                           [0.8, 0.6, 1.7]])
>>> code, dist = vq(features, code_book)
>>> code
tensor([1, 1, 0], dtype=torch.int32)
>>> dist
tensor([0.4359, 0.7348, 0.8307])

请注意,上述示例适用于 PyTorch CPU 张量。对于 GPU 张量或 CuPy 数组,vq 的预期结果是 TypeError,因为 vq 在其实现中使用了编译代码,而编译代码在 GPU 上无法工作。

更严格的数组输入验证将拒绝 np.matrixnp.ma.MaskedArray 实例,以及具有 object dtype 的数组

>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> code_book = np.array([[1., 1., 1.],
...                       [2., 2., 2.]])
>>> features  = np.array([[1.9, 2.3, 1.7],
...                       [1.5, 2.5, 2.2],
...                       [0.8, 0.6, 1.7]])
>>> vq(features, code_book)
(array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))

>>> # The above uses numpy arrays; trying to use np.matrix instances or object
>>> # arrays instead will yield an exception with `SCIPY_ARRAY_API=1`:
>>> vq(np.asmatrix(features), code_book)
...
TypeError: 'numpy.matrix' are not supported

>>> vq(np.ma.asarray(features), code_book)
...
TypeError: 'numpy.ma.MaskedArray' are not supported

>>> vq(features.astype(np.object_), code_book)
...
TypeError: object arrays are not supported

当前支持的功能#

当设置环境变量时,以下模块提供数组 API 标准支持

上述模块中的单个函数在文档中提供了功能表,如下所示。如果表中没有,则该函数尚不支持 NumPy 以外的后端。

示例功能表#

CPU

GPU

NumPy

不适用

CuPy

不适用

PyTorch

JAX

⚠️ 无 JIT

Dask

不适用

在上面的示例中,该功能对 NumPy、CuPy、PyTorch 和 JAX 数组提供了一定的支持,但不支持 Dask 数组。某些后端,如 JAX 和 PyTorch,原生支持多种设备(CPU 和 GPU),但 SciPy 对此类数组的支持可能有限;例如,此 SciPy 功能预计仅适用于位于 CPU 上的 JAX 数组。此外,某些后端可能存在主要注意事项;在示例中,该函数在 jax.jit 中运行时将失败。其他注意事项可能会在函数的 docstring 中列出。

虽然表中标记为“不适用”的元素本质上超出了范围,但我们仍在不断努力完善其余部分。Dask 封装 NumPy 以外的后端(特别是 CuPy)目前超出范围,但未来可能会改变。

请查看跟踪问题以获取更新。

实现说明#

对数组 API 标准的支持以及 NumPy、CuPy 和 PyTorch 的特定兼容性函数的关键部分是通过 array-api-compat 提供的。此包通过 git 子模块(在 scipy/_lib 下)包含在 SciPy 代码库中,因此没有引入新的依赖项。

array-api-compat 提供了通用实用函数并添加了别名,例如 xp.concat(在 NumPy 2.0 中添加 np.concat 之前,对于 numpy,它映射到 np.concatenate)。这允许在 NumPy、PyTorch、CuPy 和 JAX 之间使用统一的 API(其他库,如 Dask,正在开发中)。

当未设置环境变量从而禁用 SciPy 中的数组 API 标准支持时,我们仍然使用 NumPy 命名空间的封装版本,即 array_api_compat.numpy。这不应改变 SciPy 函数的行为,因为它实际上是现有的 numpy 命名空间,并添加了一些别名和少量为数组 API 标准支持而修改/添加的函数。当支持启用时,xp = array_namespace(input) 将是与输入数组类型匹配函数的标准兼容命名空间(例如,如果 cluster.vq.kmeans 的输入是 PyTorch 张量,则 xparray_api_compat.torch)。

为 SciPy 函数添加数组 API 标准支持#

尽可能地,添加到 SciPy 的新代码应尽量遵循数组 API 标准(这些函数通常也是 NumPy 用法的最佳实践习惯用法)。通过遵循该标准,有效添加对数组 API 标准的支持通常很简单,我们理想情况下不需要维护任何自定义。

各种辅助函数可在 scipy._lib._array_api 中获得——请参阅该模块中的 __all__ 以获取当前辅助函数的列表,并查看其 docstrings 以获取更多信息。

要为在 .py 文件中定义的 SciPy 函数添加支持,您需要更改的是:

  1. 输入数组验证,

  2. 使用 xp 而不是 np 函数,

  3. 在调用编译代码时,先将数组转换为 NumPy 数组,然后将其转换回输入数组类型。

输入数组验证使用以下模式:

xp = array_namespace(arr) # where arr is the input array
# alternatively, if there are multiple array inputs, include them all:
xp = array_namespace(arr1, arr2)

# replace np.asarray with xp.asarray
arr = xp.asarray(arr)
# uses of non-standard parameters of np.asarray can be replaced with _asarray
arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)

请注意,如果一个输入是非 NumPy 数组类型,则所有类数组输入都必须是该类型;尝试将非 NumPy 数组与列表、Python 标量或其他任意 Python 对象混合将引发异常。对于 NumPy 数组,出于向后兼容性原因,这些类型将继续被接受。

如果一个函数只调用一次编译代码,请使用以下模式:

x = np.asarray(x)  # convert to numpy right before compiled call(s)
y = _call_compiled_code(x)
y = xp.asarray(y)  # convert back to original array type

如果有多次调用编译代码,请确保只进行一次转换以避免过多的开销。

这是一个假设的公共 SciPy 函数 toto 的示例:

def toto(a, b):
    a = np.asarray(a)
    b = np.asarray(b, copy=True)

    c = np.sum(a) - np.prod(b)

    # this is some C or Cython call
    d = cdist(c)

    return d

您将这样进行转换:

def toto(a, b):
    xp = array_namespace(a, b)
    a = xp.asarray(a)
    b = xp_copy(b, xp=xp)  # our custom helper is needed for copy

    c = xp.sum(a) - xp.prod(b)

    # this is some C or Cython call
    c = np.asarray(c)
    d = cdist(c)
    d = xp.asarray(d)

    return d

经过编译代码需要回到 NumPy 数组,因为 SciPy 的扩展模块只适用于 NumPy 数组(或者在 Cython 的情况下是 memoryview)。对于 CPU 上的数组,转换应该是零拷贝的,而在 GPU 和其他设备上,尝试转换将引发异常。其原因是,设备之间隐式的数据传输被认为是不良实践,因为它很可能是一个大型且难以检测的性能瓶颈。

添加测试#

要在多个数组后端上运行测试,您应该向其添加 xp fixture,该 fixture 的值是当前测试的数组命名空间。

以下 pytest 标记可用:

  • skip_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None): 跳过某些后端或后端类别。有关如何使用此标记跳过测试的信息,请参阅 scipy.conftest.skip_or_xfail_xp_backends 的 docstring。

  • xfail_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None): xfail 某些后端或后端类别。有关如何使用此标记 xfail 测试的信息,请参阅 scipy.conftest.skip_or_xfail_xp_backends 的 docstring。

  • skip_xp_invalid_arg 用于跳过在使用 SCIPY_ARRAY_API 启用时无效的参数的测试。例如,scipy.stats 函数的一些测试将 masked array 传递给被测试的函数,但 masked array 与数组 API 不兼容。使用 skip_xp_invalid_arg 装饰器允许这些测试在不使用 SCIPY_ARRAY_API 时防止回归,同时在使用 SCIPY_ARRAY_API 时不会导致失败。随着时间的推移,我们希望这些函数在接收到数组 API 无效输入时发出弃用警告,并且此装饰器将检查是否发出了弃用警告而不会导致测试失败。当 SCIPY_ARRAY_API=1 行为成为默认且唯一的行为时,这些测试(以及装饰器本身)将被删除。

  • array_api_backends: 此标记由 xp fixture 自动添加到所有使用它的测试中。这对于例如选择所有且仅此类测试很有用。

    python dev.py test -b all -m array_api_backends
    

scipy._lib._array_api 包含与数组无关的断言,例如 xp_assert_close,可用于替换来自 numpy.testing 的断言。

当这些断言在使用 xp fixture 的测试中执行时,它们会强制实际数组和预期数组的命名空间与 fixture 设置的命名空间匹配。没有 xp fixture 的测试从预期数组推断命名空间。可以通过将 xp= 参数显式传递给断言函数来覆盖此机制。

以下示例演示如何使用这些标记:

from scipy.conftest import skip_xp_invalid_arg
from scipy._lib._array_api import xp_assert_close
...
@pytest.mark.skip_xp_backends(np_only=True, reason='skip reason')
def test_toto1(self, xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    xp_assert_close(toto(a, b), a)
...
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto2(self, xp):
    ...
...
# Do not run when SCIPY_ARRAY_API is used
@skip_xp_invalid_arg
def test_toto_masked_array(self):
    ...

将后端名称传递给 exceptions 意味着它们不会被 cpu_only=Trueeager_only=True 跳过。当某些(但不是所有)非 CPU 后端实现了委托,并且 CPU 代码路径需要转换为 NumPy 以用于编译代码时,这很有用。

# array-api-strict and CuPy will always be skipped, for the given reasons.
# All libraries using a non-CPU device will also be skipped, apart from
# JAX, for which delegation is implemented (hence non-CPU execution is supported).
@pytest.mark.skip_xp_backends(cpu_only=True, exceptions=['jax.numpy'])
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto(self, xp):
    ...

应用这些标记后,可以使用新选项 -b--array-api-backend 来使用 dev.py test

python dev.py test -b numpy -b torch -s cluster

这会自动适当地设置 SCIPY_ARRAY_API。要使用非默认设备测试具有多个设备的库,可以设置第二个环境变量(SCIPY_DEVICE,仅在测试套件中使用)。有效值取决于所测试的数组库,例如,对于 PyTorch,有效值为 "cpu", "cuda", "mps"。要使用 PyTorch MPS 后端运行测试套件,请使用:SCIPY_DEVICE=mps python dev.py test -b torch

请注意,有一个 GitHub Actions 工作流在 CPU 上使用 array-api-strict、PyTorch 和 JAX 进行测试。

测试 JAX JIT 编译器#

JAX JIT 编译器 对所有由 @jax.jit 封装的代码引入了特殊限制,这些限制在急切模式下运行 JAX 时不存在。值得注意的是,不支持 __getitem__at 中的布尔掩码,并且您不能通过对其应用 bool()float()np.asarray() 等来具体化数组。

为了正确测试带有 JAX 的 SciPy,您需要在单元测试调用它们之前,使用 @jax.jit 包装被测试的 SciPy 函数。要实现这一点,您应该在测试模块中按如下方式标记它们:

from scipy._lib._lazy_testing import lazy_xp_function
from scipy.mymodule import toto

lazy_xp_function(toto)

def test_toto(xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    # When xp==jax.numpy, toto is wrapped with @jax.jit
    xp_assert_close(toto(a, b), a)

请参阅 scipy/_lib/_lazy_testing.py 中的完整文档。

附加信息#

以下是一些在开发阶段启发了设计决策并提供了帮助的额外资源:

  • 包含一些讨论的初始 PR

  • 从此 PR 快速开始,并从 scikit-learn 中获得了一些启发。

  • 为 scikit-learn 添加数组 API 支持的 PR

  • 其他相关的 scikit-learn PR:#22554#25956