本文由 SparkML 系列作者 SparkML 授权转发,原文出处:【Spark ML系列】 Kmeans聚类算法由来原理方法示例源码分析 - 知乎
Spark Kmeans聚类算法由来原理方法示例源码分析
K-means++ 是一种改进的 K-means 聚类算法,旨在选择更好的初始聚类中心点。以下是 Spark 中 K-means++ 的全面介绍
Kmeans 由来
传统的 K-means 算法是通过随机选择初始聚类中心点来开始聚类过程。然而,随机选择的初始中心可能导致收敛到局部最优解的风险增加,即结果聚类质量较差。
为了解决这个问题,Arthur 和 Vassilvitskii 在 2007 年提出了 K-means++ 算法。它使用一种概率性的方法来选择初始聚类中心,从而使得初始中心能够更好地代表数据集,减少陷入局部最优解的可能性。
Kmeans 原理
K-means++ 的初始化过程如下:
- 从数据集中随机选择一个样本作为第一个聚类中心。
- 对于每个样本,计算它与已选择的聚类中心的最短距离(即到最近的聚类中心的距离)。
- 根据每个样本与已选择的聚类中心的最短距离,以概率分布的方式选择下一个聚类中心。距离较远的样本被选择为下一个聚类中心的概率较大。
- 重复步骤 2 和 3,直到选择出 k 个聚类中心。
K-means++ 的核心思想是通过引入概率选择机制,使得选择的初始聚类中心在整个数据集中分布更加均匀。这样做的好处是能够更好地代表数据集,从而改善最终聚类结果的质量。
在 Spark 中,K-means++ 算法被用作默认的初始化算法,它可以通过设置 initMode
参数为 "k-means||"
来启用。通过使用 K-means++ 初始化算法,Spark 能够提供更可靠和高质量的聚类结果。
示例 RDD 版
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()
// Cluster the data into two classes using KMeans
val numClusters = 2
val numIterations = 20
val clusters = KMeans.train(parsedData, numClusters, numIterations)
// Evaluate clustering by computing Within Set Sum of Squared Errors
val WSSSE = clusters.computeCost(parsedData)
println(s"Within Set Sum of Squared Errors = $WSSSE")
// Save and load model
clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel")
val sameModel = KMeansModel.load(sc, "target/org/apache/spark/KMeansExample/KMeansModel")
示例 DataFrame 版本
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
// Loads data.
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
// Trains a k-means model.
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)
// Make predictions
val predictions = model.transform(dataset)
// Evaluate clustering by computing Silhouette score
val evaluator = new ClusteringEvaluator()
val silhouette = evaluator.evaluate(predictions)
println(s"Silhouette with squared euclidean distance = $silhouette")
// Shows the result.
println("Cluster Centers: ")
model.clusterCenters.foreach(println)
方法详细说明
load:从指定路径加载 KMeans 模型。
load(path: String): KMeansModel
从指定路径加载保存的 KMeans 模型。
参数:path:模型的保存路径。
返回值:KMeansModel 对象,表示加载的 KMeans 模型。
read:返回一个用于读取 KMeans 模型的 MLReader 对象。
read(): MLReader[KMeansModel]
返回一个用于读取 KMeans 模型的 MLReader 对象。
返回值:MLReader[KMeansModel] 对象,用于读取 KMeans 模型。
k:获取聚类数目(k)的参数。
k: IntParam
获取聚类数目(k)的参数。注意,实际返回的聚类数目可能少于 k,例如如果数据集中的不同点数量少于 k。
返回值:IntParam 对象,表示聚类数目。
initMode:获取初始化算法的参数。
initMode: Param[String]
获取初始化算法的参数。可以选择 “random” 以随机选择初始聚类中心,或选择 “k-means||” 使用 k-means++ 的并行变种算法。
返回值:Param[String] 对象,表示初始化算法。
initSteps:获取 k-means|| 初始化模式的步数参数。
initSteps: IntParam
获取 k-means|| 初始化模式的步数参数。这是一个高级设置,通常默认值 2 已经足够。
返回值:IntParam 对象,表示初始化步数。
solver:获取优化方法的参数。
solver: Param[String]
获取用于优化的方法的参数。支持的选项有:“auto”、“row” 和 “block”。默认为 “auto”。
返回值:Param[String] 对象,表示优化方法。
maxBlockSizeInMB:获取将输入数据堆叠到块中的最大内存限制的参数。
maxBlockSizeInMB: DoubleParam
获取将输入数据堆叠到块中的最大内存限制的参数。默认值为 0.0,表示自动选择最佳值。
返回值:DoubleParam 对象,表示最大内存限制。
distanceMeasure:获取距离度量的参数。
distanceMeasure: Param[String]
获取距离度量的参数。支持的选项有:“euclidean” 和 “cosine”。
返回值:Param[String] 对象,表示距离度量。
tol:获取迭代算法的收敛容忍度的参数。
tol: DoubleParam
获取迭代算法的收敛容忍度的参数。必须大于等于 0。
返回值:DoubleParam 对象,表示容忍度。
seed:获取随机种子的参数。
seed: LongParam
获取随机种子的参数。
返回值:LongParam 对象,表示随机种子。
maxIter:获取最大迭代次数的参数。
maxIter: IntParam
获取最大迭代次数的参数。必须大于等于 0。
返回值:IntParam 对象,表示最大迭代次数。
fit:将模型应用于输入数据,拟合出一个模型。
fit(dataset: Dataset[_]): KMeansModel
将模型应用于输入数据,并拟合出一个 KMeansModel 模型。
参数:dataset:要拟合的输入数据集。
返回值:KMeansModel 对象,表示拟合得到的 KMeans 模型。
transformSchema:检查转换的有效性,并根据输入模式推导出输出模式。
transformSchema(schema: StructType): StructType
检查转换的有效性,并根据输入模式推导出输出模式。
参数:schema:输入数据的模式。
返回值:StructType 对象,表示输出数据的模式。
中文源码
object KMeans
/**
* 用于调用K-means聚类的顶层方法。
*/
@Since("0.8.0")
object KMeans {
// 初始化模式名称
@Since("0.8.0")
val RANDOM = "随机"
@Since("0.8.0")
val K_MEANS_PARALLEL = "k-means||"
/**
* 使用给定的参数训练一个K-means模型。
*
* @param data 训练数据点,作为`Vector`类型的RDD。
* @param k 要创建的簇的数量。
* @param maxIterations 允许的最大迭代次数。
* @param initializationMode 初始化算法。可以是"random"或"k-means||"。(默认值:"k-means||")
* @param seed 随机种子用于簇初始化。默认是基于系统时间生成种子。
*/
@Since("2.1.0")
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
initializationMode: String,
seed: Long): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setInitializationMode(initializationMode)
.setSeed(seed)
.run(data)
}
/**
* 使用给定的参数训练一个K-means模型。
*
* @param data 训练数据点,作为`Vector`类型的RDD。
* @param k 要创建的簇的数量。
* @param maxIterations 允许的最大迭代次数。
* @param initializationMode 初始化算法。可以是"random"或"k-means||"。(默认值:"k-means||")
*/
@Since("2.1.0")
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
initializationMode: String): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setInitializationMode(initializationMode)
.run(data)
}
/**
* 使用给定的参数训练一个K-means模型。
*
* @param data 训练数据点,作为`Vector`类型的RDD。
* @param k 要创建的簇的数量。
* @param maxIterations 允许的最大迭代次数。
* @param runs 从Spark 2.0.0开始,此参数不起作用。
* @param initializationMode 初始化算法。可以是"random"或"k-means||"。(默认值:"k-means||")
* @param seed 随机种子用于簇初始化。默认是基于系统时间生成种子。
*/
@Since("1.3.0")
@deprecated("使用不带'runs'参数的train方法", "2.1.0")
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
runs: Int,
initializationMode: String,
seed: Long): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setInitializationMode(initializationMode)
.setSeed(seed)
.run(data)
}
/**
* 使用给定的参数训练一个K-means模型。
*
* @param data 训练数据点,作为`Vector`类型的RDD。
* @param k 要创建的簇的数量。
* @param maxIterations 允许的最大迭代次数。
* @param runs 从Spark 2.0.0开始,此参数不起作用。
* @param initializationMode 初始化算法。可以是"random"或"k-means||"。(默认值:"k-means||")
*/
@Since("0.8.0")
@deprecated("使用不带'runs'参数的train方法", "2.1.0")
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
runs: Int,
initializationMode: String): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setInitializationMode(initializationMode)
.run(data)
}
/**
* 使用指定的参数和未指定参数的默认值训练一个K-means模型。
*/
@Since("0.8.0")
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.run(data)
}
/**
* 使用指定的参数和未指定参数的默认值训练一个K-means模型。
*/
@Since("0.8.0")
@deprecated("使用不带'runs'参数的train方法", "2.1.0")
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
runs: Int): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.run(data)
}
private[spark] def validateInitMode(initMode: String): Boolean = {
initMode match {
case KMeans.RANDOM => true
case KMeans.K_MEANS_PARALLEL => true
case _ => false
}
}
}
/**
* 带有其范数的向量,用于快速计算距离。
*/
private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double)
extends Serializable {
def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
def this(array: Array[Double]) = this(Vectors.dense(array))
/** 将向量转换为稠密向量。 */
def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
}
class KMeans
/**
* K-means聚类算法,使用k-means++的初始化模式(由Bahmani等人的k-means||算法实现)。
*
* 这是一个迭代算法,会对数据进行多次遍历,因此传递给它的RDD应该由用户缓存。
*/
@Since("0.8.0")
class KMeans private (
private var k: Int,
private var maxIterations: Int,
private var initializationMode: String,
private var initializationSteps: Int,
private var epsilon: Double,
private var seed: Long,
private var distanceMeasure: String) extends Serializable with Logging {
@Since("0.8.0")
private def this(k: Int, maxIterations: Int, initializationMode: String, initializationSteps: Int,
epsilon: Double, seed: Long) =
this(k, maxIterations, initializationMode, initializationSteps,
epsilon, seed, DistanceMeasure.EUCLIDEAN)
/**
* 使用默认参数构造KMeans实例:{k: 2, maxIterations: 20,
* initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: 随机数,
* distanceMeasure: "euclidean"}。
*/
@Since("0.8.0")
def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong(),
DistanceMeasure.EUCLIDEAN)
/**
* 聚类簇的数量(k)。
*
* @note 返回的聚类簇数量可能少于k,例如,如果要聚类的不同点数量少于k。
*/
@Since("1.4.0")
def getK: Int = k
/**
* 设置聚类簇的数量(k)。
*
* @note 返回的聚类簇数量可能少于k,例如,如果要聚类的不同点数量少于k。默认值:2。
*/
@Since("0.8.0")
def setK(k: Int): this.type = {
require(k > 0,
s"聚类簇数量必须为正数,但得到了 ${k}")
this.k = k
this
}
/**
* 允许的最大迭代次数。
*/
@Since("1.4.0")
def getMaxIterations: Int = maxIterations
/**
* 设置允许的最大迭代次数。默认值:20。
*/
@Since("0.8.0")
def setMaxIterations(maxIterations: Int): this.type = {
require(maxIterations >= 0,
s"最大迭代次数必须为非负数,但得到了 ${maxIterations}")
this.maxIterations = maxIterations
this
}
/**
* 初始化算法。可以是"random"或者"k-means||"。
*/
@Since("1.4.0")
def getInitializationMode: String = initializationMode
/**
* 设置初始化算法。可以是"random"(随机选择初始聚类中心点),或者"k-means||"(使用k-means++的并行版本,
* 由Bahmani等人提出,参考Scalable K-Means++,VLDB 2012)。默认值:"k-means||"。
*/
@Since("0.8.0")
def setInitializationMode(initializationMode: String): this.type = {
KMeans.validateInitMode(initializationMode)
this.initializationMode = initializationMode
this
}
/**
* 这个函数自Spark 2.0.0开始没有效果。
*/
@Since("1.4.0")
@deprecated("这个函数没有效果,并且总是返回1", "2.1.0")
def getRuns: Int = {
logWarning("自Spark 2.0.0开始,获取运行次数没有效果。")
1
}
/**
* 这个函数自Spark 2.0.0开始没有效果。
*/
@Since("0.8.0")
@deprecated("这个函数没有效果", "2.1.0")
def setRuns(runs: Int): this.type = {
logWarning("自Spark 2.0.0开始,设置运行次数没有效果。")
this
}
/**
* k-means||初始化模式的步骤数量。
*/
@Since("1.4.0")
def getInitializationSteps: Int = initializationSteps
/**
* 设置k-means||初始化模式的步骤数量。这是一个高级设置,一般默认值2就足够了。默认值:2。
*/
@Since("0.8.0")
def setInitializationSteps(initializationSteps: Int): this.type = {
require(initializationSteps > 0,
s"初始化步骤数量必须为正数,但得到了 ${initializationSteps}")
this.initializationSteps = initializationSteps
this
}
/**
* 聚类中心点收敛的距离阈值。
*/
@Since("1.4.0")
def getEpsilon: Double = epsilon
/**
* 设置聚类中心点收敛的距离阈值。如果所有中心点移动的欧氏距离小于该阈值,算法将停止迭代。默认值:0.0001。
*/
@Since("0.8.0")
def setEpsilon(epsilon: Double): this.type = {
require(epsilon >= 0,
s"距离阈值必须为非负数,但得到了 ${epsilon}")
this.epsilon = epsilon
this
}
/**
* 聚类初始化的随机种子。
*/
@Since("1.4.0")
def getSeed: Long = seed
/**
* 设置聚类初始化的随机种子。
*/
@Since("1.4.0")
def setSeed(seed: Long): this.type = {
this.seed = seed
this
}
/**
* 算法使用的距离度量。
*/
@Since("2.4.0")
def getDistanceMeasure: String = distanceMeasure
/**
* 设置算法使用的距离度量。
*/
@Since("2.4.0")
def setDistanceMeasure(distanceMeasure: String): this.type = {
DistanceMeasure.validateDistanceMeasure(distanceMeasure)
this.distanceMeasure = distanceMeasure
this
}
// 初始聚类中心可以作为KMeansModel对象提供,而不是使用随机初始化或k-means||初始化模式
private var initialModel: Option[KMeansModel] = None
/**
* 设置初始起始点,绕过随机初始化或k-means||
* 必须满足 model.k == this.k 的条件,否则会引发 IllegalArgumentException。
*/
@Since("1.4.0")
def setInitialModel(model: KMeansModel): this.type = {
require(model.k == k, "聚类簇数量不匹配")
initialModel = Some(model)
this
}
/**
* 在给定的数据集上训练一个K-means模型;为了获得高性能,`data`应该被缓存,
* 因为这是一个迭代算法。
*/
@Since("0.8.0")
def run(data: RDD[Vector]): KMeansModel = {
run(data, None)
}
private[spark] def run(
data: RDD[Vector],
instr: Option[Instrumentation]): KMeansModel = {
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("输入数据没有直接缓存,如果其父RDD也没有缓存,可能会影响性能。")
}
// 计算向量的平方范数并缓存
val norms = data.map(Vectors.norm(_, 2.0))
norms.persist()
val zippedData = data.zip(norms).map { case (v, norm) =>
new VectorWithNorm(v, norm)
}
val model = runAlgorithm(zippedData, instr)
norms.unpersist()
// 在运行结束时也进行警告,提高可见性
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("输入数据没有直接缓存,如果其父RDD也没有缓存,可能会影响性能。")
}
model
}
/**
* K-Means算法的实现。
*/
private def runAlgorithm(
data: RDD[VectorWithNorm],
instr: Option[Instrumentation]): KMeansModel = {
val sc = data.sparkContext
val initStartTime = System.nanoTime()
val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
val centers = initialModel match {
case Some(kMeansCenters) =>
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data, distanceMeasureInstance)
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(f"使用 $initializationMode 初始化花费了 $initTimeInSeconds%.3f 秒。")
var converged = false
var cost = 0.0
var iteration = 0
val iterationStartTime = System.nanoTime()
instr.foreach(_.logNumFeatures(centers.head.vector.size))
// 使用Lloyd算法执行迭代,直到收敛或达到最大迭代次数
while (iteration < maxIterations && !converged) {
val costAccum = sc.doubleAccumulator
val bcCenters = sc.broadcast(centers)
// 寻找新的聚类中心点
val collected = data.mapPartitions { points =>
val thisCenters = bcCenters.value
val dims = thisCenters.head.vector.size
val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
val counts = Array.fill(thisCenters.length)(0L)
points.foreach { point =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
costAccum.add(cost)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
counts(bestCenter) += 1
}
counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
}.collectAsMap()
if (iteration == 0) {
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
}
val newCenters = collected.mapValues { case (sum, count) =>
distanceMeasureInstance.centroid(sum, count)
}
bcCenters.destroy(blocking = false)
// 更新聚类中心和成本
converged = true
newCenters.foreach { case (j, newCenter) =>
if (converged &&
!distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
converged = false
}
centers(j) = newCenter
}
cost = costAccum.value
iteration += 1
}
val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
logInfo(f"迭代花费了 $iterationTimeInSeconds%.3f 秒。")
if (iteration == maxIterations) {
logInfo(s"KMeans达到最大迭代次数:$maxIterations。")
} else {
logInfo(s"KMeans在 $iteration 次迭代中收敛。")
}
logInfo(s"成本为 $cost。")
new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
}
/**
* 随机初始化一组聚类中心点。
*/
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
// 无放回选择;如果数据的不同点数量小于k,可能仍会产生重复点,
// 因此要去除重复的聚类中心,以保持与k-means||在相同情况下的行为一致
data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
.map(_.vector).distinct.map(new VectorWithNorm(_))
}
/**
* 使用Bahmani等人提出的k-means||算法(Scalable K-Means++,VLDB 2012)初始化一组聚类中心。
* 这是k-means++的变体,通过从随机中心开始,并进行多次迭代选择更多的中心,选择概率与其与当前聚类集的平方距离成比例,
* 来寻找不相似的聚类中心。它可以得到对最优聚类的可证明近似结果。
*
* 原始论文可在 http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf 找到。
*/
private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm],
distanceMeasureInstance: DistanceMeasure): Array[VectorWithNorm] = {
// 初始化空的中心点和点的成本。
var costs = data.map(_ => Double.PositiveInfinity)
// 将第一个中心点初始化为随机点。
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(false, 1, seed)
// 如果数据为空,可能会为空;提前失败并显示更好的消息:
require(sample.nonEmpty, s"从 $data 中没有可用样本")
val centers = ArrayBuffer[VectorWithNorm]()
var newCenters = Seq(sample.head.toDense)
centers ++= newCenters
// 在每一步中,平均采样2 * k个点,选择概率与其与中心的平方距离成比例。
// 注意,每次迭代只计算点与新中心点之间的距离。
var step = 0
val bcNewCentersList = ArrayBuffer[Broadcast[_]]()
while (step < initializationSteps) {
val bcNewCenters = data.context.broadcast(newCenters)
bcNewCentersList += bcNewCenters
val preCosts = costs
costs = data.zip(preCosts).map { case (point, cost) =>
math.min(distanceMeasureInstance.pointCost(bcNewCenters.value, point), cost)
}.persist(StorageLevel.MEMORY_AND_DISK)
val sumCosts = costs.sum()
bcNewCenters.unpersist(blocking = false)
preCosts.unpersist(blocking = false)
val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
pointCosts.filter { case (_, c) => rand.nextDouble() < 2.0 * c * k / sumCosts }.map(_._1)
}.collect()
newCenters = chosen.map(_.toDense)
centers ++= newCenters
step += 1
}
costs.unpersist(blocking = false)
bcNewCentersList.foreach(_.destroy(false))
val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_))
if (distinctCenters.size <= k) {
distinctCenters.toArray
} else {
// 最后,可能会有多于k个不同的候选中心点;根据数据集中映射到每个中心点的点的数量,
// 权衡每个候选中心点,并在加权的中心点上运行本地k-means++以选择k个中心点。
val bcCenters = data.context.broadcast(distinctCenters)
val countMap = data
.map(distanceMeasureInstance.findClosest(bcCenters.value, _)._1)
.countByValue()
bcCenters.destroy(blocking = false)
val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30)
}
}
}
class KMeansModel
/**
* K-means的聚类模型。每个点属于最近中心点的簇。
*/
@Since("0.8.0")
class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
@Since("2.4.0") val distanceMeasure: String,
@Since("2.4.0") val trainingCost: Double,
private[spark] val numIter: Int)
extends Saveable with Serializable with PMMLExportable {
private val distanceMeasureInstance: DistanceMeasure =
DistanceMeasure.decodeFromString(distanceMeasure)
private val clusterCentersWithNorm =
if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
@Since("2.4.0")
private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) =
this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1)
@Since("1.1.0")
def this(clusterCenters: Array[Vector]) =
this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN, 0.0, -1)
/**
* 一个适用于Java的构造函数,接受Vector的Iterable。
*/
@Since("1.4.0")
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
/**
* 聚类的总数。
*/
@Since("0.8.0")
def k: Int = clusterCentersWithNorm.length
/**
* 返回给定点所属的簇索引。
*/
@Since("0.8.0")
def predict(point: Vector): Int = {
distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
}
/**
* 将给定的点映射到它们的簇索引。
*/
@Since("1.0.0")
def predict(points: RDD[Vector]): RDD[Int] = {
val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
points.map(p =>
distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
}
/**
* 将给定的点映射到它们的簇索引。
*/
@Since("1.0.0")
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
/**
* 计算该模型在给定数据上的K-means成本(点到最近中心点的平方距离之和)。
*/
@Since("0.8.0")
def computeCost(data: RDD[Vector]): Double = {
val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
val cost = data.map(p =>
distanceMeasureInstance.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p)))
.sum()
bcCentersWithNorm.destroy(blocking = false)
cost
}
@Since("1.4.0")
override def save(sc: SparkContext, path: String): Unit = {
KMeansModel.SaveLoadV2_0.save(sc, this, path)
}
override protected def formatVersion: String = "2.0"
}
object KMeansModel
@Since("1.4.0")
object KMeansModel extends Loader[KMeansModel] {
@Since("1.4.0")
override def load(sc: SparkContext, path: String): KMeansModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
val classNameV2_0 = SaveLoadV2_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
SaveLoadV1_0.load(sc, path)
case (className, "2.0") if className == classNameV2_0 =>
SaveLoadV2_0.load(sc, path)
case _ => throw new Exception(
s"KMeansModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)\n" +
s" ($classNameV2_0, 2.0)")
}
}
private case class Cluster(id: Int, point: Vector)
private object Cluster {
def apply(r: Row): Cluster = {
Cluster(r.getInt(0), r.getAs[Vector](1))
}
}
private[clustering] object SaveLoadV1_0 {
private val thisFormatVersion = "1.0"
private[clustering]
val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
Cluster(id, p.vector)
}
spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): KMeansModel = {
implicit val formats = DefaultFormats
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val k = (metadata \ "k").extract[Int]
val centroids = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[Cluster](centroids.schema)
val localCentroids = centroids.rdd.map(Cluster.apply).collect()
assert(k == localCentroids.length)
new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
}
}
private[clustering] object SaveLoadV2_0 {
private val thisFormatVersion = "2.0"
private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure)
~ ("trainingCost" -> model.trainingCost)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
Cluster(id, p.vector)
}
spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): KMeansModel = {
implicit val formats = DefaultFormats
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val k = (metadata \ "k").extract[Int]
val centroids = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[Cluster](centroids.schema)
val localCentroids = centroids.rdd.map(Cluster.apply).collect()
assert(k == localCentroids.length)
val distanceMeasure = (metadata \ "distanceMeasure").extract[String]
val trainingCost = (metadata \ "trainingCost").extract[Double]
new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure, trainingCost, -1)
}
}
}