scipy.special.huber#

scipy.special.huber(delta, r, out=None) = <ufunc 'huber'>#

Huber 损失函数。

\[\begin{split}\text{huber}(\delta, r) = \begin{cases} \infty & \delta < 0 \\ \frac{1}{2}r^2 & 0 \le \delta, | r | \le \delta \\ \delta ( |r| - \frac{1}{2}\delta ) & \text{otherwise} \end{cases}\end{split}\]
参数:
deltandarray

输入数组,指示二次 vs. 线性损失变更点。

rndarray

输入数组,可能表示残差。

outndarray,可选

函数值的可选项输出数组

返回:
标量或 ndarray

计算的 Huber 损失函数值。

另请参见

pseudo_huber

此函数的平滑逼近

备注

huber 可用作稳健统计或机器学习中的损失函数,与常见的平方误差损失相比,可降低离群值的影响,大于 delta 大小的残差不进行平方 [1]

通常,r 表示残差,这是模型预测与数据之间的差值。然后,对于 \(|r|\leq\delta\)huber 类似于平方误差,对于 \(|r|>\delta\),类似于绝对误差。这样,Huber 损失经常在与平方误差损失函数相似的残差模型拟合中实现快速收敛,同时还能像绝对误差损失一样降低异常值的影响 (\(|r|>\delta\))。由于 \(\delta\) 是平方误差和绝对误差模式之间的临界值,因此必须针对每个问题仔细调整 \(\delta\) 的值。huber 也是凸函数,这使其适用于基于梯度的优化。

在 0.15.0 版本中添加。

参考

[1]

Peter Huber。“位置参数的稳健估计”,1964 年。统计年鉴。53 (1): 73 - 101。

示例

导入所有必需的模块。

>>> import numpy as np
>>> from scipy.special import huber
>>> import matplotlib.pyplot as plt

r=2 计算 delta=1 的函数

>>> huber(1., 2.)
1.5

通过为 delta 提供 NumPy 数组或列表来计算不同 delta 的函数。

>>> huber([1., 3., 5.], 4.)
array([3.5, 7.5, 8. ])

通过为 r 提供 NumPy 数组或列表来计算不同点处的函数。

>>> huber(2., np.array([1., 1.5, 3.]))
array([0.5  , 1.125, 4.   ])

可以通过为 deltar 提供具有兼容广播形状的数组,为不同的 deltar 计算函数。

>>> r = np.array([1., 2.5, 8., 10.])
>>> deltas = np.array([[1.], [5.], [9.]])
>>> print(r.shape, deltas.shape)
(4,) (3, 1)
>>> huber(deltas, r)
array([[ 0.5  ,  2.   ,  7.5  ,  9.5  ],
       [ 0.5  ,  3.125, 27.5  , 37.5  ],
       [ 0.5  ,  3.125, 32.   , 49.5  ]])

绘制不同 delta 的函数。

>>> x = np.linspace(-4, 4, 500)
>>> deltas = [1, 2, 3]
>>> linestyles = ["dashed", "dotted", "dashdot"]
>>> fig, ax = plt.subplots()
>>> combined_plot_parameters = list(zip(deltas, linestyles))
>>> for delta, style in combined_plot_parameters:
...     ax.plot(x, huber(delta, x), label=fr"$\delta={delta}$", ls=style)
>>> ax.legend(loc="upper center")
>>> ax.set_xlabel("$x$")
>>> ax.set_title(r"Huber loss function $h_{\delta}(x)$")
>>> ax.set_xlim(-4, 4)
>>> ax.set_ylim(0, 8)
>>> plt.show()
../../_images/scipy-special-huber-1.png