支持 Array API 标准#

注意

Array API 标准支持仍处于实验阶段,且隐藏在环境变量之后。目前仅覆盖了公共 API 的一小部分。

本指南介绍了如何 使用 以及 添加对 Python array API 标准 的支持。该标准允许用户在 SciPy 的部分功能中直接使用任何兼容 array API 的数组库。

RFC 定义了 SciPy 如何实现对该标准的支持,其核心原则是 “输入为何种数组类型,输出即为何种数组类型”。此外,该实现对允许的类数组(array-like)输入进行更严格的验证,例如:拒绝 NumPy 矩阵(matrix)和掩码数组(masked array)实例,以及对象(object)数据类型的数组。

在下文中,兼容 array API 的命名空间表示为 xp

使用 Array API 标准支持#

要启用 Array API 标准支持,必须在导入 SciPy 之前设置环境变量

export SCIPY_ARRAY_API=1

这既启用了 Array API 标准支持,也启用了对类数组参数更严格的输入验证。请注意,该环境变量旨在作为临时手段,以便在不立即影响向后兼容性的情况下,进行增量更改并将其合并到 ``main`` 分支。我们不打算长期保留此环境变量。

这个聚类示例展示了将 PyTorch 张量作为输入和返回值的用法

>>> import torch
>>> from scipy.cluster.vq import vq
>>> code_book = torch.tensor([[1., 1., 1.],
...                           [2., 2., 2.]])
>>> features  = torch.tensor([[1.9, 2.3, 1.7],
...                           [1.5, 2.5, 2.2],
...                           [0.8, 0.6, 1.7]])
>>> code, dist = vq(features, code_book)
>>> code
tensor([1, 1, 0], dtype=torch.int32)
>>> dist
tensor([0.4359, 0.7348, 0.8307])

请注意,上述示例适用于 PyTorch CPU 张量。对于 GPU 张量或 CuPy 数组,vq 预期会返回 TypeError,因为 vq 在其实现中使用了编译代码,而这些代码无法在 GPU 上运行。

更严格的数组输入验证将拒绝 np.matrixnp.ma.MaskedArray 实例,以及 object 数据类型的数组

>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> code_book = np.array([[1., 1., 1.],
...                       [2., 2., 2.]])
>>> features  = np.array([[1.9, 2.3, 1.7],
...                       [1.5, 2.5, 2.2],
...                       [0.8, 0.6, 1.7]])
>>> vq(features, code_book)
(array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))

>>> # The above uses numpy arrays; trying to use np.matrix instances or object
>>> # arrays instead will yield an exception with `SCIPY_ARRAY_API=1`:
>>> vq(np.asmatrix(features), code_book)
...
TypeError: 'numpy.matrix' are not supported

>>> vq(np.ma.asarray(features), code_book)
...
TypeError: 'numpy.ma.MaskedArray' are not supported

>>> vq(features.astype(np.object_), code_book)
...
TypeError: object arrays are not supported

功能示例表#

CPU

GPU

NumPy

不适用

CuPy

不适用

PyTorch

JAX

⚠️ 无 JIT

Dask

不适用

在上面的示例中,该特性对 NumPy、CuPy、PyTorch 和 JAX 数组有一定的支持,但不支持 Dask 数组。某些后端(如 JAX 和 PyTorch)原生支持多设备(CPU 和 GPU),但 SciPy 对此类数组的支持可能有限;例如,此 SciPy 特性预计仅适用于位于 CPU 上的 JAX 数组。此外,某些后端可能存在重大注意事项;在示例中,该函数在 jax.jit 内部运行时将失败。其他注意事项可能会在函数的文档字符串(docstring)中列出。

虽然表中标记为“n/a”的元素本质上不在讨论范围内,但我们一直在努力完善其余部分。目前,Dask 封装 NumPy 以外的后端(特别是 CuPy)尚不在讨论范围内,但未来可能会发生变化。

请查看 追踪问题(tracker issue) 以获取更新。

实现说明#

对 Array API 标准的支持以及针对 NumPy、CuPy 和 PyTorch 的特定兼容性功能的关键部分是通过 array-api-compat 提供的。该软件包通过 git 子模块(位于 scipy/_lib 下)包含在 SciPy 代码库中,因此不会引入新的依赖项。

array-api-compat 提供了通用的工具函数,并添加了别名,例如 xp.concat(在 NumPy 2.0 添加 np.concat 之前,对于 NumPy,它映射到 np.concatenate)。这使得在 NumPy、PyTorch、CuPy 和 JAX 之间使用统一的 API 成为可能(其他库如 Dask 正在开发中)。

当未设置环境变量且 SciPy 中的 Array API 标准支持被禁用时,我们仍然使用封装版本的 NumPy 命名空间,即 array_api_compat.numpy。这不应改变 SciPy 函数的行为,因为它实际上就是现有的 numpy 命名空间,只是添加了一些别名,并为支持 Array API 标准修改/添加了少量函数。当启用支持时,xp = array_namespace(input) 将是与函数输入数组类型匹配的符合标准的命名空间(例如,如果 cluster.vq.kmeans 的输入是 PyTorch 张量,那么 xp 就是 array_api_compat.torch)。

向 SciPy 函数添加 Array API 标准支持#

添加到 SciPy 的新代码应尽可能严格遵循 Array API 标准(这些函数通常也是 NumPy 使用的最佳实践)。通过遵循该标准,添加对 Array API 标准的支持通常非常直接,理想情况下我们不需要维护任何自定义代码。

scipy._lib._array_api 中提供了各种辅助函数 —— 请参阅该模块中的 __all__ 列表以获取当前的辅助函数,并参阅其文档字符串以获取更多信息。

要为定义在 .py 文件中的 SciPy 函数添加支持,你需要更改的是:

  1. 输入数组验证,

  2. 使用 xp 而不是 np 函数,

  3. 当调用编译代码时,在调用前将数组转换为 NumPy 数组,在调用后将其转换回输入数组类型。

输入数组验证使用以下模式

xp = array_namespace(arr) # where arr is the input array
# alternatively, if there are multiple array inputs, include them all:
xp = array_namespace(arr1, arr2)

# replace np.asarray with xp.asarray
arr = xp.asarray(arr)
# uses of non-standard parameters of np.asarray can be replaced with _asarray
arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)

请注意,如果一个输入是非 NumPy 数组类型,则所有类数组输入都必须是该类型;尝试将非 NumPy 数组与列表、Python 标量或其他任意 Python 对象混合将引发异常。对于 NumPy 数组,出于向后兼容性的原因,将继续接受这些类型。

如果一个函数只调用一次编译代码,请使用以下模式

x = np.asarray(x)  # convert to numpy right before compiled call(s)
y = _call_compiled_code(x)
y = xp.asarray(y)  # convert back to original array type

如果多次调用编译代码,请确保仅执行一次转换以避免过多的开销。

这是一个假设的 SciPy 公共函数 toto 的示例

def toto(a, b):
    a = np.asarray(a)
    b = np.asarray(b, copy=True)

    c = np.sum(a) - np.prod(b)

    # this is some C or Cython call
    d = cdist(c)

    return d

你可以像这样转换它

def toto(a, b):
    xp = array_namespace(a, b)
    a = xp.asarray(a)
    b = xp_copy(b, xp=xp)  # our custom helper is needed for copy

    c = xp.sum(a) - xp.prod(b)

    # this is some C or Cython call
    c = np.asarray(c)
    d = cdist(c)
    d = xp.asarray(d)

    return d

通过编译代码需要转换回 NumPy 数组,因为 SciPy 的扩展模块仅支持 NumPy 数组(或 Cython 中的 memoryview)。对于 CPU 上的数组,转换应该是零拷贝的,而在 GPU 和其他设备上,转换尝试将引发异常。原因是设备之间静默的数据传输被认为是糟糕的做法,因为它很可能成为巨大且难以检测的性能瓶颈。

添加测试#

要在多个数组后端上运行测试,你应该向其添加 xp fixture,其值为当前测试的数组命名空间。

可以使用以下 pytest marker:

  • skip_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None):跳过某些后端或后端类别。有关如何使用此 marker 跳过测试的信息,请参阅 scipy.conftest.skip_or_xfail_xp_backends 的文档字符串。

  • xfail_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None):将某些后端或后端类别标记为预期失败(xfail)。有关如何使用此 marker 将测试标记为 xfail 的信息,请参阅 scipy.conftest.skip_or_xfail_xp_backends 的文档字符串。

  • skip_xp_invalid_arg 用于跳过那些在启用 SCIPY_ARRAY_API 时使用无效参数的测试。例如,一些 scipy.stats 函数的测试会将掩码数组传递给被测函数,但掩码数组与 Array API 不兼容。使用 skip_xp_invalid_arg 装饰器允许这些测试在不使用 SCIPY_ARRAY_API 时防止回归,同时在启用 SCIPY_ARRAY_API 时不会导致失败。将来,我们希望这些函数在接收到 Array API 无效输入时发出弃用警告,而此装饰器将检查是否发出了弃用警告且不导致测试失败。当 SCIPY_ARRAY_API=1 的行为成为默认且唯一的行为时,这些测试(以及装饰器本身)将被移除。

  • array_api_backends:此 marker 由 xp fixture 自动添加到所有使用它的测试中。这对于选择所有且仅选择此类测试非常有用。

    spin test -b all -m array_api_backends
    

scipy._lib._array_api 包含与数组无关的断言,例如 xp_assert_close,可用于替换来自 numpy.testing 的断言。

当在使用了 xp fixture 的测试中执行这些断言时,它们会强制要求实际值和期望值数组的命名空间都必须与 fixture 设置的命名空间相匹配。没有 xp fixture 的测试会从期望数组中推断命名空间。可以通过显式地将 xp= 参数传递给断言函数来覆盖此机制。

以下示例演示了如何使用这些 marker

from scipy.conftest import skip_xp_invalid_arg
from scipy._lib._array_api import xp_assert_close
...
@pytest.mark.skip_xp_backends(np_only=True, reason='skip reason')
def test_toto1(self, xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    xp_assert_close(toto(a, b), a)
...
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto2(self, xp):
    ...
...
# Do not run when SCIPY_ARRAY_API is used
@skip_xp_invalid_arg
def test_toto_masked_array(self):
    ...

将后端名称传递给 exceptions 意味着它们不会被 cpu_only=Trueeager_only=True 跳过。当为某些(但非全部)非 CPU 后端实现了委托,且 CPU 代码路径需要为编译代码转换为 NumPy 时,这很有用。

# array-api-strict and CuPy will always be skipped, for the given reasons.
# All libraries using a non-CPU device will also be skipped, apart from
# JAX, for which delegation is implemented (hence non-CPU execution is supported).
@pytest.mark.skip_xp_backends(cpu_only=True, exceptions=['jax.numpy'])
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto(self, xp):
    ...

应用这些 marker 后,spin test 可以使用新选项 -b--array-api-backend

spin test -b numpy -b torch -s cluster

这会自动相应地设置 SCIPY_ARRAY_API。要测试具有多个设备且使用非默认设备的库,可以设置第二个环境变量(SCIPY_DEVICE,仅在测试套件中使用)。有效值取决于正在测试的数组库,例如对于 PyTorch,有效值为 "cpu", "cuda", "mps"。要使用 PyTorch MPS 后端运行测试套件,请使用:SCIPY_DEVICE=mps spin test -b torch

请注意,GitHub Actions 工作流中包含在 CPU 上使用 array-api-strict、PyTorch 和 JAX 进行测试的内容。

测试实践#

对于任何受支持的函数 f,存在使用 xp fixture 的测试是很重要的,这些测试将备选后端的使用仅限制在被测函数 f 上。在测试中为了生成参考值、输入、往返计算等而调用的其他函数,应改用 NumPy 后端。这有助于确保后端上发生的任何失败确实与感兴趣的函数有关,并避免了因缺乏对 f 以外函数支持而不得不跳过后端的情况。检查不同函数间使用相同备选后端是否保持某种不变性的基于属性的集成测试也很有价值,可以让我们了解模块对后端支持的总体健康状况。但为了确保测试套件确实反映了每个函数对后端支持的状态,拥有将被选后端的使用隔离在被测函数上的测试至关重要。

为了帮助实现这种后端隔离,在 scipy._lib._array_api 中有一个函数 _xp_copy_to_numpy,它可以将任意 xp 数组复制到 NumPy 数组,绕过任何设备传输防护,同时保留数据类型。此函数只能在针对被测函数以外的函数的测试中使用,这一点至关重要。尝试在测试之外将设备数组复制到 NumPy 应当失败,否则一个函数是否在 GPU 上运行将变得不透明。

当尝试将备选后端的使用隔离到特定函数时,必须注意 PyTorch 允许设置默认数据类型,并且 SciPy 会在默认数据类型为 float32float64 的情况下分别进行测试(这由环境变量 SCIPY_DEFAULT_DTYPE 控制)。使用 xp fixture 的测试依赖于 xp.asarray 在给定列表输入且未指定显式数据类型时生成具有默认数据类型的数组。这意味着,如果测试涉及获取输入数组并将其传递给被测函数以外的函数以生成被测函数的输入,以下写法可能看起来很自然,但不会产生正确的数据类型行为

# z, p, k will have dtype float64 regardless of the value of
# SCIPY_DEFAULT_DTYPE
z = np.asarray([1j, -1j, 2j, -2j])
p = np.asarray([1+1j, 3-100j, 3+100j, 1-1j])
k = 23

# np.poly will preserve dtype
b = k * np.poly(z_np).real
a = np.poly(p_np).real
# Input arrays z, p, and reference outputs b, a will all have
# dtype float64.
z, p, b, a = map(xp.asarray, (z, p, b, a))

# With float64 inputs, the outputs bp and ap will be of dtype
# float64. Note that the parameter k is a Python scalar which does
# not impact output dtype for NumPy >= 2.0.
bp, ap = zpk2tf(z, p, k)
# xp_assert_close checks for matching dtype. Due to the way the
# code was written above, zpk2tf is not tested with float32 inputs
# when SCIPY_DEFAULT_DTYPE is float32.
xp_assert_close(b, bp)
xp_assert_close(a, ap)

相反,可以将所有输入构造为 xp 数组,然后复制到 NumPy 数组,以确保遵循默认数据类型

# calls to xp.asarray will respect the default dtype.
z = xp.asarray([1j, -1j, 2j, -2j])
p = xp.asarray([1+1j, 3-100j, 3+100j, 1-1j])
k = 23

# _xp_copy_to_numpy preserves dtype, as does np.poly.
b = k * np.poly(_xp_copy_to_numpy(z)).real
a = np.poly(_xp_copy_to_numpy(p)).real
# b and a will have dtype float32
b, a = map(xp.asarray, (b, a))

# zpk2tf is tested with float32 inputs when SCIPY_DEFAULT_DTYPE=float32
# as intended.
bp, ap = zpk2tf(z, p, k)
xp_assert_close(b, bp)
xp_assert_close(a, ap)

测试 JAX JIT 编译器#

JAX JIT 编译器 对所有被 @jax.jit 封装的代码引入了特殊限制,这些限制在以 eager 模式运行 JAX 时不存在。值得注意的是,不支持在 __getitem__at 中使用布尔掩码,并且不能通过对数组应用 bool()float()np.asarray() 等来使其具体化(materialize)。

为了用 JAX 正确测试 scipy,你需要在单元测试调用被测 scipy 函数之前,用 @jax.jit 对其进行封装。为此,你应该在测试模块中按如下方式对其进行标记

from scipy._lib.array_api_extra.testing import lazy_xp_function
from scipy.mymodule import toto

lazy_xp_function(toto)

def test_toto(xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    # When xp==jax.numpy, toto is wrapped with @jax.jit
    xp_assert_close(toto(a, b), a)

参阅完整文档 此处

附加信息#

这里有一些额外的资源,它们促成了某些设计决策并在开发阶段提供了帮助

  • 带有讨论的初始 PR

  • 从这个 PR 快速开始,并从 scikit-learn 获得了一些启发。

  • 为 scikit-learn 添加 Array API 支持的 PR

  • 其他一些相关的 scikit-learn PR:#22554#25956

API 覆盖范围#

下面的表格显示了 SciPy 各模块对备选后端的支持现状。目前表格中仅包含公共函数和类函数的可调用对象,但计划最终也将相关的公共类纳入其中。被认为超出范围的函数不予考虑。如果一个模块或子模块不包含任何在范围内的函数,则将其从表格中排除。例如,scipy.spatial.transform 目前被排除在外,因为它的 API 不包含函数,但将来范围扩大到包含类时可能会被纳入。scipy.odrscipy.datasets 被排除在外,因为它们的内容被认为超出范围。

目前还没有关于哪些函数应被视为超出备选后端支持范围的正式政策。遵循的一些通用经验法则是排除以下内容:

  • 不操作数组的函数,如 scipy.constants.value

  • 过于依赖具体实现的函数,例如 scipy.linalg.blas 中的函数,它们提供了底层 BLAS 例程的直接封装。

  • 在加速计算设备上执行本质上非常困难甚至不可能高效计算的函数。

举个例子。scipy.odr 的内容由于上述原因 2 和 3 的结合被视为超出范围。scipy.odr 本质上提供了对单体式 ODRPACK Fortran 库的直接封装,其 API 与该单体库的结构绑定。创建一个高效的 GPU 加速非线性加权正交距离回归实现本身也是一个具有挑战性的问题。尽管如此,关于什么是范围内内容的考虑正在演变,如果表现出足够的用户兴趣和可行性,现在被认为超出范围的内容将来可能会被判定为在范围内。

注意

下面显示的覆盖百分比可能低于真实值,因为在开发注册此类支持的基础设施之前,就已经为某些函数添加了备选后端支持。这种情况通过在百分比旁边放置星号来表示。备选后端支持的文档目前正在完善中。

CPU 支持#

模块

torch

jax

dask

cluster.vq (4)

100%

100%

100%

cluster.hierarchy (29)

97%

97%

97%

constants (3)

100%

100%

100%

differentiate (3)

100%

100%

0%

fft (32)

100%

94%

100%

integrate (19)

37%

21%

26%

interpolate (14)

43%*

43%*

43%*

io (9)

0%*

0%*

0%*

linalg (95)

3%*

3%*

2%*

linalg.interpolative (9)

0%*

0%*

0%*

ndimage (73)

100%

100%

100%

optimize (57)

7%*

4%*

7%*

optimize.elementwise (4)

75%

0%

0%

signal (140)

64%

57%

61%

signal.windows (26)

96%

88%

92%

sparse (35)

0%*

0%*

0%*

sparse.linalg (32)

0%*

0%*

0%*

sparse.csgraph (25)

0%*

0%*

0%*

spatial (9)

0%*

0%*

0%*

spatial.distance (27)

0%*

0%*

0%*

special (340)

28%*

28%*

28%*

stats (133)

54%

50%

35%

stats.contingency (7)

0%*

0%*

0%*

stats.qmc (4)

0%*

0%*

0%*

GPU 支持#

模块

cupy

torch

jax

cluster.vq (4)

25%

25%

25%

cluster.hierarchy (29)

28%

28%

28%

constants (3)

100%

100%

100%

differentiate (3)

100%

100%

100%

fft (32)

75%

75%

75%

integrate (19)

42%

37%

21%

interpolate (14)

0%*

0%*

0%*

io (9)

0%*

0%*

0%*

linalg (95)

3%*

3%*

3%*

linalg.interpolative (9)

0%*

0%*

0%*

ndimage (73)

93%

0%

1%

optimize (57)

7%*

7%*

4%*

optimize.elementwise (4)

100%

75%

0%

signal (140)

69%

34%

20%

signal.windows (26)

96%

96%

88%

sparse (35)

0%*

0%*

0%*

sparse.linalg (32)

0%*

0%*

0%*

sparse.csgraph (25)

0%*

0%*

0%*

spatial (9)

0%*

0%*

0%*

spatial.distance (27)

0%*

0%*

0%*

special (340)

28%*

12%*

12%*

stats (133)

43%

41%

44%

stats.contingency (7)

0%*

0%*

0%*

stats.qmc (4)

0%*

0%*

0%*

JIT 支持#

模块

jax

cluster.vq (4)

25%

cluster.hierarchy (29)

79%

constants (3)

100%

differentiate (3)

0%

fft (32)

94%

integrate (19)

11%

interpolate (14)

0%*

io (9)

0%*

linalg (95)

1%*

linalg.interpolative (9)

0%*

ndimage (73)

1%

optimize (57)

4%*

optimize.elementwise (4)

0%

signal (140)

21%

signal.windows (26)

88%

sparse (35)

0%*

sparse.linalg (32)

0%*

sparse.csgraph (25)

0%*

spatial (9)

0%*

spatial.distance (27)

0%*

special (340)

12%*

stats (133)

28%

stats.contingency (7)

0%*

stats.qmc (4)

0%*