scipy.linalg.

solve#

scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, check_finite=True, assume_a=None, transposed=False)[source]#

解方程 a @ x = b 中的 x,其中 a 是一个方阵。

如果已知数据矩阵的特定类型,则向 assume_a 键提供相应的字符串会选择专用的求解器。可用选项包括:

对角矩阵

‘diagonal’

三对角矩阵

‘tridiagonal’

带状矩阵

‘banded’

上三角矩阵

‘upper triangular’

下三角矩阵

‘lower triangular’

对称矩阵

‘symmetric’ (或 ‘sym’)

厄米特矩阵

‘hermitian’ (或 ‘her’)

对称正定矩阵

‘positive definite’ (或 ‘pos’)

一般矩阵

‘general’ (或 ‘gen’)

此函数的数组参数可以在核心形状前添加额外的“批次”维度。在这种情况下,数组被视为低维切片的批次;有关详细信息,请参阅 批次线性运算

参数:
aarray_like, shape (…, N, N)

左侧方阵或矩阵批次。

b(…, N, NRHS) array_like

右侧输入数据或右侧批次。

lowerbool, default: False

除非 assume_a'sym''her''pos' 之一,否则此参数将被忽略。如果为 True,计算仅使用 a 的下三角数据;对角线上方的条目将被忽略。如果为 False(默认值),计算仅使用 a 的上三角数据;对角线下方的条目将被忽略。

overwrite_abool, default: False

允许覆盖 a 中的数据(可能会提高性能)。

overwrite_bbool, default: False

允许覆盖 b 中的数据(可能会提高性能)。

check_finitebool, default: True

是否检查输入矩阵是否仅包含有限数。禁用检查可能会获得性能提升,但如果输入确实包含无穷大或 NaN,则可能会导致问题(崩溃、无法终止)。

assume_astr, optional

有效条目如上所述。如果省略或为 None,将执行检查以识别结构,以便调用适当的求解器。

transposedbool, default: False

如果为 True,则求解 a.T @ x == b。对于复数 a,会引发 NotImplementedError

返回:
xndarray, shape (N, NRHS) or (…, N)

解数组。

引发:
ValueError

如果检测到大小不匹配或输入 a 不是方阵。

LinAlgError

如果由于矩阵奇异性导致计算失败。

LinAlgWarning

如果检测到病态输入 a。

NotImplementedError

如果 transposed 为 True 且输入 a 是一个复数矩阵。

附注

如果输入的 b 矩阵是一个具有 N 个元素的 1-D 数组,并且与 NxN 输入 a 一起提供,尽管存在明显的大小不匹配,它仍被视为有效的列向量。这与 numpy.dot() 行为兼容,返回的结果仍然是 1-D 数组。

一般、对称、厄米特和正定解分别通过调用 LAPACK 的 ?GETRF/?GETRS、?SYSV、?HESV 和 ?POTRF/?POTRS 例程获得。

数组的数据类型决定了调用哪个求解器,而与值无关。换句话说,即使复数数组条目具有精确的零虚部,也会根据数组的数据类型调用复数求解器。

示例

给定 ab,求解 x

>>> import numpy as np
>>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
>>> b = np.array([2, 4, -1])
>>> from scipy.linalg import solve
>>> x = solve(a, b)
>>> x
array([ 2., -2.,  9.])
>>> a @ x == b
array([ True,  True,  True], dtype=bool)

支持矩阵批次,无论是否进行结构检测

>>> a = np.arange(12).reshape(3, 2, 2)   # a batch of 3 2x2 matrices
>>> A = a.transpose(0, 2, 1) @ a    # A is a batch of 3 positive definite matrices
>>> b = np.ones(2)
>>> solve(A, b)      # this automatically detects that A is pos.def.
array([[ 1. , -0.5],
       [ 3. , -2.5],
       [ 5. , -4.5]])
>>> solve(A, b, assume_a='pos')   # bypass structucture detection
array([[ 1. , -0.5],
       [ 3. , -2.5],
       [ 5. , -4.5]])