外推技巧和窍门#

scipy.interpolate 中,不同例程对插值器在插值数据域之外的查询点的评估(即外推)处理方式并不完全一致。不同的插值器使用不同的关键字参数集来控制数据域之外的行为:一些使用 extrapolate=True/False/None,一些允许 fill_value 关键字。有关每个特定插值例程的详细信息,请参阅 API 文档。

根据具体问题,可用的关键字可能不足。需要注意非线性插值的外推。通常情况下,随着距离数据域越来越远,外推结果的意义越来越小。当然,这是可以预期的:插值器只知道数据域内的数据。

当默认的外推结果不符合要求时,用户需要自己实现所需的外推模式。

在本教程中,我们考虑了几个实例,其中演示了使用可用关键字和手动实现所需外推模式。这些实例可能适用于也可能不适用于您的具体问题;它们不一定是最佳实践;它们故意简化到仅包含演示主要思想所需的必要内容,希望它们能激发您处理具体问题的灵感。

interp1d:复制 numpy.interp 的左填充值和右填充值#

简而言之:使用 fill_value=(left, right)

numpy.interp 使用常数外推,并默认在插值区间内扩展 y 数组的第一个和最后一个值:np.interp(xnew, x, y) 的输出为 y[0](对于 xnew < x[0])和 y[-1](对于 xnew > x[-1])。

默认情况下,interp1d 拒绝外推,并在对插值范围之外的数据点进行评估时引发 ValueError。这可以通过 bounds_error=False 参数来关闭:然后 interp1d 使用 fill_value 设置超出范围的值,默认情况下为 nan

为了模仿 numpy.interpinterp1d 的行为,您可以利用它支持将 2 元组用作 fill_value 的事实。然后,元组元素分别用于填充 xnew < min(x)x > max(x)。对于多维 y,这些元素必须与 y 的形状相同或可广播到该形状。

为了说明

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

x = np.linspace(0, 1.5*np.pi, 11)
y = np.column_stack((np.cos(x), np.sin(x)))   # y.shape is (11, 2)

func = interp1d(x, y,
                axis=0,  # interpolate along columns
                bounds_error=False,
                kind='linear',
                fill_value=(y[0], y[-1]))
xnew = np.linspace(-np.pi, 2.5*np.pi, 51)
ynew = func(xnew)

fix, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.plot(xnew, ynew[:, 0])
ax1.plot(x, y[:, 0], 'o')

ax2.plot(xnew, ynew[:, 1])
ax2.plot(x, y[:, 1], 'o')
plt.tight_layout()
../../_images/extrapolation_examples-1.png

CubicSpline 扩展边界条件#

CubicSpline 需要两个额外的边界条件,这些条件由 bc_type 参数控制。此参数可以列出边缘处导数的显式值,也可以使用有用的别名。例如,bc_type="clamped" 将一阶导数设置为零,bc_type="natural" 将二阶导数设置为零(另外两个识别的字符串值为“periodic”和“not-a-knot”)

虽然外推由边界条件控制,但它们之间的关系并不直观。例如,人们可能会期望对于 bc_type="natural",外推是线性的。这种期望过于强烈:每个边界条件只设置一个点的导数,仅在边界处。外推是从第一个和最后一个多项式片段进行的,对于自然样条曲线来说,这是一个在给定点具有零二阶导数的三次曲线。

从另一个角度来看,这种期望过于强烈的原因是,考虑一个只有三个数据点的集合,其中样条曲线有两个多项式片段。为了线性外推,这种期望意味着这两个片段都是线性的。但是,然后,两个线性片段无法在具有连续二阶导数的中点处匹配!(除非当然,如果所有三个数据点实际上都位于一条直线上)。

为了说明这种行为,我们考虑一个合成数据集并比较几种边界条件

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline

xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [4.5, 3.6, 1.6, 0.0, -3.3, -3.1, -1.8, -1.7]

notaknot = CubicSpline(xs, ys, bc_type='not-a-knot')
natural = CubicSpline(xs, ys, bc_type='natural')
clamped = CubicSpline(xs, ys, bc_type='clamped')
xnew = np.linspace(min(xs) - 4, max(xs) + 4, 101)

splines = [notaknot, natural, clamped]
titles = ['not-a-knot', 'natural', 'clamped']

fig, axs = plt.subplots(3, 3, figsize=(12, 12))
for i in [0, 1, 2]:
    for j, spline, title in zip(range(3), splines, titles):
        axs[i, j].plot(xs, spline(xs, nu=i),'o')
        axs[i, j].plot(xnew, spline(xnew, nu=i),'-')
        axs[i, j].set_title(f'{title}, deriv={i}')

plt.tight_layout()
plt.show()
../../_images/extrapolation_examples-2.png

很明显,自然样条曲线在边界处确实具有零二阶导数,但外推是非线性的。bc_type="clamped" 显示类似的行为:一阶导数仅在边界处正好等于零。在所有情况下,外推都是通过扩展样条曲线的第一个和最后一个多项式片段来完成的,无论它们是什么。

一种强制外推的可能方法是扩展插值域,以添加第一个和最后一个多项式片段,这些片段具有所需属性。

在这里,我们使用 CubicSpline 超类 PPolyextend 方法来添加两个额外的断点,并确保附加的多项式片段保持导数的值。然后,使用这两个附加区间进行外推。

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline

def add_boundary_knots(spline):
    """
    Add knots infinitesimally to the left and right.

    Additional intervals are added to have zero 2nd and 3rd derivatives,
    and to maintain the first derivative from whatever boundary condition
    was selected. The spline is modified in place.
    """
    # determine the slope at the left edge
    leftx = spline.x[0]
    lefty = spline(leftx)
    leftslope = spline(leftx, nu=1)

    # add a new breakpoint just to the left and use the
    # known slope to construct the PPoly coefficients.
    leftxnext = np.nextafter(leftx, leftx - 1)
    leftynext = lefty + leftslope*(leftxnext - leftx)
    leftcoeffs = np.array([0, 0, leftslope, leftynext])
    spline.extend(leftcoeffs[..., None], np.r_[leftxnext])

    # repeat with additional knots to the right
    rightx = spline.x[-1]
    righty = spline(rightx)
    rightslope = spline(rightx,nu=1)
    rightxnext = np.nextafter(rightx, rightx + 1)
    rightynext = righty + rightslope * (rightxnext - rightx)
    rightcoeffs = np.array([0, 0, rightslope, rightynext])
    spline.extend(rightcoeffs[..., None], np.r_[rightxnext])

xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [4.5, 3.6, 1.6, 0.0, -3.3, -3.1, -1.8, -1.7]

notaknot = CubicSpline(xs,ys, bc_type='not-a-knot')
# not-a-knot does not require additional intervals

natural = CubicSpline(xs,ys, bc_type='natural')
# extend the natural natural spline with linear extrapolating knots
add_boundary_knots(natural)

clamped = CubicSpline(xs,ys, bc_type='clamped')
# extend the clamped spline with constant extrapolating knots
add_boundary_knots(clamped)

xnew = np.linspace(min(xs) - 5, max(xs) + 5, 201)

fig, axs = plt.subplots(3, 3,figsize=(12,12))

splines = [notaknot, natural, clamped]
titles = ['not-a-knot', 'natural', 'clamped']

for i in [0, 1, 2]:
    for j, spline, title in zip(range(3), splines, titles):
        axs[i, j].plot(xs, spline(xs, nu=i),'o')
        axs[i, j].plot(xnew, spline(xnew, nu=i),'-')
        axs[i, j].set_title(f'{title}, deriv={i}')

plt.tight_layout()
plt.show()
../../_images/extrapolation_examples-3.png

手动实现渐近线#

之前扩展插值域的技巧依赖于 CubicSpline.extend 方法。一种更通用的替代方法是实现一个包装器,显式地处理超出边界的情况。让我们考虑一个实例。

设置#

假设我们要在 \(a\) 的给定值下求解方程

\[a x = 1/\tan{x}\;.\]

(这些类型的方程出现在求解量子粒子的能级问题中)。为简单起见,我们只考虑 \(x\in (0, \pi/2)\)

一次求解这个方程很简单

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq

def f(x, a):
    return a*x - 1/np.tan(x)

a = 3
x0 = brentq(f, 1e-16, np.pi/2, args=(a,))   # here we shift the left edge
                                            # by a machine epsilon to avoid
                                            # a division by zero at x=0
xx = np.linspace(0.2, np.pi/2, 101)
plt.plot(xx, a*xx, '--')
plt.plot(xx, 1/np.tan(xx), '--')
plt.plot(x0, a*x0, 'o', ms=12)
plt.text(0.1, 0.9, fr'$x_0 = {x0:.3f}$',
               transform=plt.gca().transAxes, fontsize=16)
plt.show()
../../_images/extrapolation_examples-4.png

但是,如果我们需要多次求解(例如,由于 tan 函数的周期性,要找到一系列根),对 scipy.optimize.brentq 的重复调用会非常耗时。

为了避免这种困难,我们对 \(y = ax - 1/\tan{x}\) 进行制表,并在制表的网格上对其进行插值。实际上,我们将使用插值:我们插值 \(x\) 相对于 \(у\) 的值。这样,求解原始方程就变成了在零 \(y\) 参数处对插值函数进行评估。

为了提高插值精度,我们将利用表格函数导数的知识。我们将使用 BPoly.from_derivatives 来构建一个三次插值函数(等效地,我们可以使用 CubicHermiteSpline)。

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import BPoly

def f(x, a):
    return a*x - 1/np.tan(x)

xleft, xright = 0.2, np.pi/2
x = np.linspace(xleft, xright, 11)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))

for j, a in enumerate([3, 93]):
    y = f(x, a)
    dydx = a + 1./np.sin(x)**2    # d(ax - 1/tan(x)) / dx
    dxdy = 1 / dydx               # dx/dy = 1 / (dy/dx)

    xdx = np.c_[x, dxdy]
    spl = BPoly.from_derivatives(y, xdx)   # inverse interpolation

    yy = np.linspace(f(xleft, a), f(xright, a), 51)
    ax[j].plot(yy, spl(yy), '--')
    ax[j].plot(y, x, 'o')
    ax[j].set_xlabel(r'$y$')
    ax[j].set_ylabel(r'$x$')
    ax[j].set_title(rf'$a = {a}$')

    ax[j].plot(0, spl(0), 'o', ms=12)
    ax[j].text(0.1, 0.85, fr'$x_0 = {spl(0):.3f}$',
               transform=ax[j].transAxes, fontsize=18)
    ax[j].grid(True)
plt.tight_layout()
plt.show()
../../_images/extrapolation_examples-5.png

注意,对于 \(a=3\)spl(0) 与上面的 brentq 调用一致,而对于 \(a = 93\),差异很大。该过程对于较大的 \(a\) 开始失败的原因是,直线 \(y = ax\) 趋向于垂直轴,而原始方程的根趋向于 \(x=0\)。由于我们在有限网格上对原始函数进行了表格化,spl(0) 对于过大的 \(a\) 值涉及外推。依赖外推容易导致精度下降,最好避免。

使用已知的渐近线#

观察原始方程,我们注意到对于 \(x\to 0\)\(\tan(x) = x + O(x^3)\),原始方程变为

\[ax = 1/x \;,\]

因此,对于 \(a \gg 1\)\(x_0 \approx 1/\sqrt{a}\)

我们将利用它来构建一个类,该类在超出范围的数据中切换从插值到使用这种已知的渐近行为。一个简单的实现可能如下所示

class RootWithAsymptotics:
   def __init__(self, a):

       # construct the interpolant
       xleft, xright = 0.2, np.pi/2
       x = np.linspace(xleft, xright, 11)

       y = f(x, a)
       dydx = a + 1./np.sin(x)**2    # d(ax - 1/tan(x)) / dx
       dxdy = 1 / dydx               # dx/dy = 1 / (dy/dx)

       # inverse interpolation
       self.spl = BPoly.from_derivatives(y, np.c_[x, dxdy])
       self.a = a

   def root(self):
       out = self.spl(0)
       asympt = 1./np.sqrt(self.a)
       return np.where(spl.x.min() < asympt, out, asympt)

然后

>>> r = RootWithAsymptotics(93)
>>> r.root()
array(0.10369517)

这与外推结果不同,并且与 brentq 调用一致。

注意,这个实现是故意简化的。从 API 的角度来看,您可能希望改而实现 __call__ 方法,以便 xy 的完全依赖关系可用。从数值的角度来看,需要做更多工作来确保插值和渐近线之间的切换发生在渐近线区域的足够深处,以便得到的函数在切换点足够平滑。

此外,在这个例子中,我们人为地将问题限制为只考虑 tan 函数的单个周期性区间,并且只处理了 \(a > 0\)。对于 \(a\) 的负值,我们需要实现其他渐近线,对于 \(x\to \pi\)

但是基本思想是一样的。

D > 1 中进行外推#

在包装类或函数中手动实现外推的基本思想可以很容易地推广到更高维度。例如,我们考虑使用 CloughTocher2DInterpolator 对二维数据进行 C1 平滑插值问题。默认情况下,它使用 nan 来填充超出范围的值,而我们希望改为使用每个查询点的最近邻的值。

由于 CloughTocher2DInterpolator 接受二维数据或数据的 Delaunay 三角剖分,因此查找查询点的最近邻的有效方法是构建三角剖分(使用 scipy.spatial 工具)并用它来查找数据凸包上的最近邻。

我们将使用一种更简单、更朴素的方法,并依赖于使用 NumPy 广播遍历整个数据集。

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CloughTocher2DInterpolator as CT

def my_CT(xy, z):
    """CT interpolator + nearest-neighbor extrapolation.

    Parameters
    ----------
    xy : ndarray, shape (npoints, ndim)
        Coordinates of data points
    z : ndarray, shape (npoints)
        Values at data points

    Returns
    -------
    func : callable
        A callable object which mirrors the CT behavior,
        with an additional neareast-neighbor extrapolation
        outside of the data range.
    """
    x = xy[:, 0]
    y = xy[:, 1]
    f = CT(xy, z)

    # this inner function will be returned to a user
    def new_f(xx, yy):
        # evaluate the CT interpolator. Out-of-bounds values are nan.
        zz = f(xx, yy)
        nans = np.isnan(zz)

        if nans.any():
            # for each nan point, find its nearest neighbor
            inds = np.argmin(
                (x[:, None] - xx[nans])**2 +
                (y[:, None] - yy[nans])**2
                , axis=0)
            # ... and use its value
            zz[nans] = z[inds]
        return zz

    return new_f

# Now illustrate the difference between the original ``CT`` interpolant
# and ``my_CT`` on a small example:

x = np.array([1, 1, 1, 2, 2, 2, 4, 4, 4])
y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
z = np.array([0, 7, 8, 3, 4, 7, 1, 3, 4])

xy = np.c_[x, y]
lut = CT(xy, z)
lut2 = my_CT(xy, z)

X = np.linspace(min(x) - 0.5, max(x) + 0.5, 71)
Y = np.linspace(min(y) - 0.5, max(y) + 0.5, 71)
X, Y = np.meshgrid(X, Y)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.plot_wireframe(X, Y, lut(X, Y), label='CT')
ax.plot_wireframe(X, Y, lut2(X, Y), color='m',
                  cstride=10, rstride=10, alpha=0.7, label='CT + n.n.')

ax.scatter(x, y, z,  'o', color='k', s=48, label='data')
ax.legend()
plt.tight_layout()
../../_images/extrapolation_examples-6.png