Spark MultivariateOnlineSummarizer source code
File path: /mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.stat

import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
 * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
 * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector
 * format in an online fashion.
 *
 * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
 * the corresponding joint dataset.
 *
 * A numerically stable algorithm is implemented to compute the mean and variance of instances:
 * Reference: <a href="http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance">
 * variance-wiki</a>
 * Zero elements (including explicit zero values) are skipped when calling add(),
 * to have time complexity O(nnz) instead of O(n) for each column.
 *
 * For weighted instances, the unbiased estimation of variance is defined by the reliability
 * weights:
 * see <a href="https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights">
 * Reliability weights (Wikipedia)</a>.
 */
@Since("1.1.0")
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {

  private var n = 0
  private var currMean: Array[Double] = _
  private var currM2n: Array[Double] = _
  private var currM2: Array[Double] = _
  private var currL1: Array[Double] = _
  private var totalCnt: Long = 0
  private var totalWeightSum: Double = 0.0
  private var weightSquareSum: Double = 0.0
  private var currWeightSum: Array[Double] = _
  private var nnz: Array[Long] = _
  private var currMax: Array[Double] = _
  private var currMin: Array[Double] = _

  /**
   * Add a new sample to this summarizer, and update the statistical summary.
   *
   * @param sample The sample in dense/sparse vector format to be added into this summarizer.
   * @return This MultivariateOnlineSummarizer object.
   */
  @Since("1.1.0")
  def add(sample: Vector): this.type = add(sample, 1.0)

  private[spark] def add(instance: Vector, weight: Double): this.type = {
    require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0")
    if (weight == 0.0) return this

    if (n == 0) {
      require(instance.size > 0, s"Vector should have dimension larger than zero.")
      n = instance.size

      currMean = Array.ofDim[Double](n)
      currM2n = Array.ofDim[Double](n)
      currM2 = Array.ofDim[Double](n)
      currL1 = Array.ofDim[Double](n)
      currWeightSum = Array.ofDim[Double](n)
      nnz = Array.ofDim[Long](n)
      currMax = Array.fill[Double](n)(Double.MinValue)
      currMin = Array.fill[Double](n)(Double.MaxValue)
    }

    require(n == instance.size, s"Dimensions mismatch when adding new sample." +
      s" Expecting $n but got ${instance.size}.")

    val localCurrMean = currMean
    val localCurrM2n = currM2n
    val localCurrM2 = currM2
    val localCurrL1 = currL1
    val localWeightSum = currWeightSum
    val localNumNonzeros = nnz
    val localCurrMax = currMax
    val localCurrMin = currMin
    instance.foreachNonZero { (index, value) =>
      if (localCurrMax(index) < value) {
        localCurrMax(index) = value
      }
      if (localCurrMin(index) > value) {
        localCurrMin(index) = value
      }
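      // Numerically stable (Welford-style) weighted update of the running mean and of
      // M2n, the running sum of weighted squared deviations from the current mean;
      // currM2 and currL1 simply accumulate weighted squares and absolute values.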
      val prevMean = localCurrMean(index)
      val diff = value - prevMean
      localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
      localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
      localCurrM2(index) += weight * value * value
      localCurrL1(index) += weight * math.abs(value)

      localWeightSum(index) += weight
      localNumNonzeros(index) += 1
    }

    totalWeightSum += weight
    weightSquareSum += weight * weight
    totalCnt += 1
    this
  }

  /**
   * Merge another MultivariateOnlineSummarizer, and update the statistical summary.
   * (Note that it's in place merging; as a result, `this` object will be modified.)
   *
   * @param other The other MultivariateOnlineSummarizer to be merged.
   * @return This MultivariateOnlineSummarizer object.
   */
  @Since("1.1.0")
  def merge(other: MultivariateOnlineSummarizer): this.type = {
    if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) {
      require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
        s"Expecting $n but got ${other.n}.")
      totalCnt += other.totalCnt
      totalWeightSum += other.totalWeightSum
      weightSquareSum += other.weightSquareSum
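      // Pairwise merge of the per-dimension statistics (parallel variance-merge formula):
      // shift the mean by the weighted delta and add the cross term
      // deltaMean^2 * thisNnz * otherNnz / totalNnz to M2n.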
      var i = 0
      while (i < n) {
        val thisNnz = currWeightSum(i)
        val otherNnz = other.currWeightSum(i)
        val totalNnz = thisNnz + otherNnz
        val totalCnnz = nnz(i) + other.nnz(i)
        if (totalNnz != 0.0) {
          val deltaMean = other.currMean(i) - currMean(i)
          // merge mean together
          currMean(i) += deltaMean * otherNnz / totalNnz
          // merge m2n together
          currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
          // merge m2 together
          currM2(i) += other.currM2(i)
          // merge l1 together
          currL1(i) += other.currL1(i)
          // merge max and min
          currMax(i) = math.max(currMax(i), other.currMax(i))
          currMin(i) = math.min(currMin(i), other.currMin(i))
        }
        currWeightSum(i) = totalNnz
        nnz(i) = totalCnnz
        i += 1
      }
    } else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) {
      this.n = other.n
      this.currMean = other.currMean.clone()
      this.currM2n = other.currM2n.clone()
      this.currM2 = other.currM2.clone()
      this.currL1 = other.currL1.clone()
      this.totalCnt = other.totalCnt
      this.totalWeightSum = other.totalWeightSum
      this.weightSquareSum = other.weightSquareSum
      this.currWeightSum = other.currWeightSum.clone()
      this.nnz = other.nnz.clone()
      this.currMax = other.currMax.clone()
      this.currMin = other.currMin.clone()
    }
    this
  }

  /**
   * Sample mean of each dimension.
   *
   */
  @Since("1.1.0")
  override def mean: Vector = {
    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
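    // currMean(i) is the mean over the nonzero entries only (add() skips zeros), so it is
    // rescaled by the fraction of the total weight actually observed at dimension i.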
    val realMean = Array.ofDim[Double](n)
    var i = 0
    while (i < n) {
      realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
      i += 1
    }
    Vectors.dense(realMean)
  }

  /**
   * Unbiased estimate of sample variance of each dimension.
   *
   */
  @Since("1.1.0")
  override def variance: Vector = {
    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
    val realVariance = Array.ofDim[Double](n)
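    // Unbiased denominator under reliability weights: sum(w) - sum(w^2) / sum(w);
    // with unit weights this reduces to the familiar n - 1.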
    val denominator = totalWeightSum - (weightSquareSum / totalWeightSum)

    // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
    if (denominator > 0.0) {
      val deltaMean = currMean
      var i = 0
      val len = currM2n.length
      while (i < len) {
        // We prevent variance from negative value caused by numerical error.
        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
          (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
        i += 1
      }
    }
    Vectors.dense(realVariance)
  }

  /**
   * Sample size.
   *
   */
  @Since("1.1.0")
  override def count: Long = totalCnt

  /**
   * Sum of weights.
   */
  override def weightSum: Double = totalWeightSum

  /**
   * Number of nonzero elements in each dimension.
   *
   */
  @Since("1.1.0")
  override def numNonzeros: Vector = {
    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
    Vectors.dense(nnz.map(_.toDouble))
  }

  /**
   * Maximum value of each dimension.
   *
   */
  @Since("1.1.0")
  override def max: Vector = {
    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
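    // add() skips zeros, so if some sample had a zero at dimension i (nnz(i) < totalCnt)
    // and the observed maximum is negative, the true maximum over all samples is 0.0.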
    var i = 0
    while (i < n) {
      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
      i += 1
    }
    Vectors.dense(currMax)
  }

  /**
   * Minimum value of each dimension.
   *
   */
  @Since("1.1.0")
  override def min: Vector = {
    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
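    // Symmetrically, a skipped zero caps the minimum at 0.0 when the observed minimum is positive.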
    var i = 0
    while (i < n) {
      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
      i += 1
    }
    Vectors.dense(currMin)
  }

  /**
   * L2 (Euclidean) norm of each dimension.
   *
   */
  @Since("1.2.0")
  override def normL2: Vector = {
    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
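    // currM2(i) accumulates sum of weight * value^2, so this is the (weighted) L2 norm.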
    val realMagnitude = Array.ofDim[Double](n)
    var i = 0
    val len = currM2.length
    while (i < len) {
      realMagnitude(i) = math.sqrt(currM2(i))
      i += 1
    }
    Vectors.dense(realMagnitude)
  }

  /**
   * L1 norm of each dimension.
   *
   */
  @Since("1.2.0")
  override def normL1: Vector = {
    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
    Vectors.dense(currL1)
  }
}
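
Below is a minimal usage sketch (not part of the Spark source) showing the add/merge pattern the class is designed for: build one summarizer per chunk of data with add(), merge() the partial results in place, then read off the summary statistics. The object name SummarizerExample and the vector values are illustrative only; note that only the unweighted add(Vector) is public, since add(Vector, weight) is private[spark].

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

object SummarizerExample {
  def main(args: Array[String]): Unit = {
    // Summarize one "partition" of data.
    val s1 = new MultivariateOnlineSummarizer()
    s1.add(Vectors.dense(1.0, 0.0, -2.0))
    s1.add(Vectors.sparse(3, Array(0, 2), Array(3.0, 4.0)))   // zeros are skipped per column

    // Summarize another "partition".
    val s2 = new MultivariateOnlineSummarizer()
    s2.add(Vectors.dense(-1.0, 5.0, 0.0))

    // In-place merge: s1 is modified and returned.
    val merged = s1.merge(s2)

    println(s"count       = ${merged.count}")
    println(s"mean        = ${merged.mean}")
    println(s"variance    = ${merged.variance}")
    println(s"numNonzeros = ${merged.numNonzeros}")
    println(s"max / min   = ${merged.max} / ${merged.min}")
  }
}

In Spark itself this same add/merge pattern is driven over an RDD (one summarizer per partition, then a tree of merges), which is how utilities such as Statistics.colStats compute column summary statistics.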