kmeans2#
- scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, rng=None)[源代码]#
使用 k-均值算法将一组观测值分类到 k 个聚类中。
该算法尝试最小化观测值和质心之间的欧几里得距离。包括几种初始化方法。
- 参数:
- datandarray
一个 ‘M’ 行 ‘N’ 列的数组,包含 ‘N’ 维度的 ‘M’ 个观测值,或者一个长度为 ‘M’ 的数组,包含 ‘M’ 个一维观测值。
- kint 或 ndarray
要形成的聚类数量以及要生成的质心数量。如果 minit 初始化字符串为 ‘matrix’,或者如果给定了 ndarray,则将其解释为要使用的初始聚类。
- iterint, 可选
要运行的 k-均值算法的迭代次数。请注意,这与 kmeans 函数的 iters 参数的含义不同。
- threshfloat, 可选
(尚未使用)
- minitstr, 可选
初始化方法。 可用的方法有 ‘random’、‘points’、‘++’ 和 ‘matrix’
‘random’:从高斯分布生成 k 个质心,其均值和方差根据数据估计。
‘points’:从数据中随机选择 k 个观测值(行)作为初始质心。
‘++’:根据 kmeans++ 方法选择 k 个观测值(仔细播种)
‘matrix’:将 k 参数解释为 k 行 M 列(或一维数据的长度为 k 的数组)的初始质心数组。
- missingstr, 可选
处理空聚类的方法。 可用的方法有 ‘warn’ 和 ‘raise’
‘warn’:给出警告并继续。
‘raise’:引发 ClusterError 并终止算法。
- check_finitebool, 可选
是否检查输入矩阵是否仅包含有限数字。禁用可能会提高性能,但如果输入包含无穷大或 NaN,可能会导致问题(崩溃、无法终止)。默认值:True
- rng{None, int,
numpy.random.Generator
}, 可选 如果通过关键字传递 rng,则会将
numpy.random.Generator
以外的类型传递给numpy.random.default_rng
以实例化Generator
。如果 rng 已经是Generator
实例,则使用提供的实例。 指定 rng 以实现可重复的函数行为。如果此参数通过位置传递,或者 seed 通过关键字传递,则应用参数 seed 的旧行为
如果 seed 为 None (或
numpy.random
),则使用numpy.random.RandomState
单例。如果 seed 是一个 int,则使用新的
RandomState
实例,并使用 seed 进行种子设置。如果 seed 已经是
Generator
或RandomState
实例,则使用该实例。
在版本 1.15.0 中更改: 作为从使用
numpy.random.RandomState
过渡到numpy.random.Generator
的 SPEC-007 的一部分,此关键字已从 seed 更改为 rng。在过渡期内,两个关键字将继续工作,但一次只能指定一个。在过渡期之后,使用 seed 关键字的函数调用将发出警告。上面概述了 seed 和 rng 的行为,但新代码中应仅使用 rng 关键字。
- 返回:
- centroidndarray
一个 ‘k’ 行 ‘N’ 列的数组,包含 k-均值最后一次迭代中找到的质心。
- labelndarray
label[i] 是第 i 个观测值最接近的质心的代码或索引。
参见
参考文献
[1]D. Arthur 和 S. Vassilvitskii,“k-means++:仔细播种的优点”,第十八届 ACM-SIAM 离散算法年度研讨会论文集,2007 年。
示例
>>> from scipy.cluster.vq import kmeans2 >>> import matplotlib.pyplot as plt >>> import numpy as np
创建 z,一个形状为 (100, 2) 的数组,其中包含来自三个多元正态分布的混合样本。
>>> rng = np.random.default_rng() >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45) >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30) >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25) >>> z = np.concatenate((a, b, c)) >>> rng.shuffle(z)
计算三个聚类。
>>> centroid, label = kmeans2(z, 3, minit='points') >>> centroid array([[ 2.22274463, -0.61666946], # may vary [ 0.54069047, 5.86541444], [ 6.73846769, 4.01991898]])
每个聚类中有多少个点?
>>> counts = np.bincount(label) >>> counts array([29, 51, 20]) # may vary
绘制聚类。
>>> w0 = z[label == 0] >>> w1 = z[label == 1] >>> w2 = z[label == 2] >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0') >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1') >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2') >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids') >>> plt.axis('equal') >>> plt.legend(shadow=True) >>> plt.show()