// File: mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.stat.correlation
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.rdd.RDD
/**
* Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix
* for an RDD of the type RDD[Vector].
*
* Definition of Spearman's correlation can be found at
* http://en.wikipedia.org/wiki/Spearman's_rank_correlation_coefficient
*/
private[stat] object SpearmanCorrelation extends Correlation with Logging {

  /**
   * Compute Spearman's correlation for two datasets.
   *
   * Delegates to `computeCorrelationWithMatrixImpl` (inherited from the `Correlation`
   * trait, not shown in this file), so the pairwise case reuses the matrix
   * implementation below.
   */
  override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = {
    computeCorrelationWithMatrixImpl(x, y)
  }

  /**
   * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the
   * correlation between column i and j.
   *
   * Strategy: replace each value with its rank within its own column (tied values share
   * the average of the ranks they span), then compute the Pearson correlation of the
   * resulting rank vectors — which is, by definition, Spearman's correlation.
   */
  override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
    // Tag each row with a unique id and explode every vector into
    // ((columnIndex, value), rowUid) entries so ranking can proceed per column.
    val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
      vec.iterator.map(t => (t, uid))
    }
    // Global sort by (columnIndex, value): all entries of a column become contiguous,
    // equal values become adjacent, and an entry's global position serves as its
    // within-column rank up to a per-column constant shift (see note further down —
    // Pearson correlation is invariant to that shift).
    val sorted = colBased.sortByKey()
    // assign global ranks (using average ranks for tied values)
    val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
      // State describing the current run of tied values: its column, its value, and
      // the global rank at which the run started.
      var preCol = -1
      var preVal = Double.NaN
      var startRank = -1.0
      // Row uids accumulated for the current tie group; emitted together on flush.
      val cachedUids = ArrayBuffer.empty[Long]
      // Emit (rowUid, (columnIndex, averageRank)) for every cached uid, where
      // averageRank is the mean of the consecutive global ranks the group occupies,
      // then reset the cache for the next group.
      val flush: () => Iterable[(Long, (Int, Double))] = () => {
        val averageRank = startRank + (cachedUids.size - 1) / 2.0
        val output = cachedUids.map { uid =>
          (uid, (preCol, averageRank))
        }
        cachedUids.clear()
        output
      }
      // NOTE(review): preCol/preVal/cachedUids are reset at each partition boundary,
      // so a tie group that straddles two partitions is flushed as two separate groups
      // with different average ranks — confirm this approximation is intended.
      iter.flatMap { case (((j, v), uid), rank) =>
        // If we see a new value or cachedUids is too big, we flush ids with their average rank.
        // The size cap bounds memory for pathologically large tie groups (at the cost of
        // splitting such a group's average rank, same effect as the partition-boundary note).
        if (j != preCol || v != preVal || cachedUids.size >= 10000000) {
          val output = flush()
          preCol = j
          preVal = v
          startRank = rank
          cachedUids += uid
          output
        } else {
          cachedUids += uid
          Iterator.empty
        }
      } ++ flush() // emit the partition's final, still-cached tie group
    }
    // Replace values in the input matrix by their ranks compared with values in the same column.
    // Note that shifting all ranks in a column by a constant value doesn't affect result.
    val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
      // sort by column index and then convert values to a vector
      Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
    }
    // Spearman's correlation = Pearson correlation of the rank vectors.
    PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
  }
}