kmeans2#
- scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, seed=None)[source]#
使用 k-means 算法将一组观察值分类为 k 个簇。
该算法尝试最小化观察值和质心之间的欧几里得距离。包含了多种初始化方法。
- 参数:
- datandarray
“M” 个“N” 维观察值的“M” 行“N” 列数组,或者“M” 个 1 维观察值的长度为“M” 的数组。
- kint 或 ndarray
要形成的簇的数目,以及要生成的质心的数目。如果 minit 初始化字符串为“matrix”,或如果给定的 ndarray 相反,则它被解释为要使用的初始簇。
- iterint,可选
运行的 k-means 算法迭代次数。请注意,这与kmeans函数的iters参数的含义不同。
- thresh浮点数,可选
(尚未使用)
- minit字符串,可选
初始化方法。可用方法包括“random”(随机)、“points”(点)、“++”(加加)和“matrix”(矩阵)
“random”(随机):根据从数据估算出的均值和方差,从高斯分布生成 k 个质心。
“points”(点):从数据中随机选取 k 个观察值(行)作为初始质心。
“++”(加加):根据 kmeans++ 方法选择 k 个观察值(小心播种)
“matrix”(矩阵):将k参数解释为k乘M(或1-D数据的长度为k的数组)的初始质心数组。
- missing字符串,可选
处理空集群的方法。可用方法包括“warn”(警告)和“raise”(引发)
“warn”(警告):发出警告并继续。
“raise”(引发):引发集群错误并终止算法。
- check_finite布尔值,可选
是否检查输入矩阵是否只包含有限数字。禁用可能会提高性能,但如果输入确实包含无穷大或 NaN,可能会导致问题(崩溃,无法终止)。默认值:True
- seed{无、整数、
numpy.random.Generator
、numpy.random.RandomState
}, 可选 用于初始化伪随机数生成器的种子。如果seed为无(或
numpy.random
),则使用numpy.random.RandomState
单例。如果seed为整数,则使用新的RandomState
实例,用seed播种。如果seed已经是Generator
或RandomState
实例,则使用该实例。默认值为无。
- 返回:
- 质心ndarray
一个“k”乘“N”的质心数组,在k-means的最后一次迭代中找到。
- 标签ndarray
label[i]是第i个观察值最接近的质心的代码或索引。
另请参阅
参考
[1]D. Arthur 和 S. Vassilvitskii,“k-means++:精细设置的优势”,2007 年 ACM-SIAM 离散算法第十八届年度研讨会论文集。
示例
>>> from scipy.cluster.vq import kmeans2 >>> import matplotlib.pyplot as plt >>> import numpy as np
创建一个具有 (100, 2) 形状的数组 z,其中包含来自三个多元正态分布的混合样本。
>>> rng = np.random.default_rng() >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45) >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30) >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25) >>> z = np.concatenate((a, b, c)) >>> rng.shuffle(z)
计算三个聚类。
>>> centroid, label = kmeans2(z, 3, minit='points') >>> centroid array([[ 2.22274463, -0.61666946], # may vary [ 0.54069047, 5.86541444], [ 6.73846769, 4.01991898]])
每个聚类中有多少个点?
>>> counts = np.bincount(label) >>> counts array([29, 51, 20]) # may vary
绘制聚类。
>>> w0 = z[label == 0] >>> w1 = z[label == 1] >>> w2 = z[label == 2] >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0') >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1') >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2') >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids') >>> plt.axis('equal') >>> plt.legend(shadow=True) >>> plt.show()