solve#
- scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, check_finite=True, assume_a=None, transposed=False)[source]#
解方程
a @ x = b中的x,其中 a 是一个方阵。如果已知数据矩阵的特定类型,则向
assume_a键提供相应的字符串会选择专用的求解器。可用选项包括:对角矩阵
‘diagonal’
三对角矩阵
‘tridiagonal’
带状矩阵
‘banded’
上三角矩阵
‘upper triangular’
下三角矩阵
‘lower triangular’
对称矩阵
‘symmetric’ (或 ‘sym’)
厄米特矩阵
‘hermitian’ (或 ‘her’)
对称正定矩阵
‘positive definite’ (或 ‘pos’)
一般矩阵
‘general’ (或 ‘gen’)
此函数的数组参数可以在核心形状前添加额外的“批次”维度。在这种情况下,数组被视为低维切片的批次;有关详细信息,请参阅 批次线性运算。
- 参数:
- aarray_like, shape (…, N, N)
左侧方阵或矩阵批次。
- b(…, N, NRHS) array_like
右侧输入数据或右侧批次。
- lowerbool, default: False
除非
assume_a是'sym'、'her'或'pos'之一,否则此参数将被忽略。如果为 True,计算仅使用 a 的下三角数据;对角线上方的条目将被忽略。如果为 False(默认值),计算仅使用 a 的上三角数据;对角线下方的条目将被忽略。- overwrite_abool, default: False
允许覆盖 a 中的数据(可能会提高性能)。
- overwrite_bbool, default: False
允许覆盖 b 中的数据(可能会提高性能)。
- check_finitebool, default: True
是否检查输入矩阵是否仅包含有限数。禁用检查可能会获得性能提升,但如果输入确实包含无穷大或 NaN,则可能会导致问题(崩溃、无法终止)。
- assume_astr, optional
有效条目如上所述。如果省略或为
None,将执行检查以识别结构,以便调用适当的求解器。- transposedbool, default: False
如果为 True,则求解
a.T @ x == b。对于复数 a,会引发 NotImplementedError。
- 返回:
- xndarray, shape (N, NRHS) or (…, N)
解数组。
- 引发:
- ValueError
如果检测到大小不匹配或输入 a 不是方阵。
- LinAlgError
如果由于矩阵奇异性导致计算失败。
- LinAlgWarning
如果检测到病态输入 a。
- NotImplementedError
如果 transposed 为 True 且输入 a 是一个复数矩阵。
附注
如果输入的 b 矩阵是一个具有 N 个元素的 1-D 数组,并且与 NxN 输入 a 一起提供,尽管存在明显的大小不匹配,它仍被视为有效的列向量。这与 numpy.dot() 行为兼容,返回的结果仍然是 1-D 数组。
一般、对称、厄米特和正定解分别通过调用 LAPACK 的 ?GETRF/?GETRS、?SYSV、?HESV 和 ?POTRF/?POTRS 例程获得。
数组的数据类型决定了调用哪个求解器,而与值无关。换句话说,即使复数数组条目具有精确的零虚部,也会根据数组的数据类型调用复数求解器。
示例
给定 a 和 b,求解 x
>>> import numpy as np >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]]) >>> b = np.array([2, 4, -1]) >>> from scipy.linalg import solve >>> x = solve(a, b) >>> x array([ 2., -2., 9.]) >>> a @ x == b array([ True, True, True], dtype=bool)
支持矩阵批次,无论是否进行结构检测
>>> a = np.arange(12).reshape(3, 2, 2) # a batch of 3 2x2 matrices >>> A = a.transpose(0, 2, 1) @ a # A is a batch of 3 positive definite matrices >>> b = np.ones(2) >>> solve(A, b) # this automatically detects that A is pos.def. array([[ 1. , -0.5], [ 3. , -2.5], [ 5. , -4.5]]) >>> solve(A, b, assume_a='pos') # bypass structucture detection array([[ 1. , -0.5], [ 3. , -2.5], [ 5. , -4.5]])