spark DTStatsAggregator 源码
spark DTStatsAggregator 代码
文件路径:/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._
/**
 * DecisionTree statistics aggregator for a node.
 * This holds a flat array of statistics for a set of (features, bins)
 * and helps with indexing.
 * A single concrete class supports learning both with and without feature subsampling:
 * pass `Some(indices)` as `featureSubset` to aggregate only those features, or `None` to
 * aggregate every feature described by `metadata`.
 *
 * @param metadata Source of the impurity type, the number of classes and the per-feature
 *                 bin counts.
 * @param featureSubset Optional array of feature indices into `metadata.numBins`. When defined,
 *                      the feature indices accepted by this class are positions within this
 *                      array rather than original feature indices.
 */
private[spark] class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    featureSubset: Option[Array[Int]]) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  private val statsSize: Int = impurityAggregator.statsSize

  /**
   * Number of bins for each feature. This is indexed by the feature index
   * (the position within `featureSubset` when a subset is given).
   */
  private val numBins: Array[Int] = featureSubset match {
    case Some(subset) => subset.map(metadata.numBins(_))
    case None => metadata.numBins
  }

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   * Has one more element than [[numBins]]; the last element is the total size.
   */
  private val featureOffsets: Array[Int] = {
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

  /**
   * Total number of elements stored in this aggregator
   */
  private val allStatsSize: Int = featureOffsets.last

  /**
   * Flat array of elements.
   * Index for start of stats for a (feature, bin) is:
   *   index = featureOffsets(featureIndex) + binIndex * statsSize
   */
  private val allStats: Array[Double] = new Array[Double](allStatsSize)

  /**
   * Array of parent node sufficient stats.
   * Note: parent stats need to be explicitly tracked in the [[DTStatsAggregator]] for unordered
   *       categorical features, because the parent [[Node]] object does not have
   *       [[ImpurityStats]] on the first iteration.
   */
  private val parentStats: Array[Double] = new Array[Double](statsSize)

  /**
   * Get an [[ImpurityCalculator]] for a given (feature, bin).
   *
   * @param featureOffset This is a pre-computed feature offset
   *                      from [[getFeatureOffset]].
   * @param binIndex Index of the bin within the feature.
   */
  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
  }

  /**
   * Get an [[ImpurityCalculator]] for the parent node.
   */
  def getParentImpurityCalculator(): ImpurityCalculator = {
    impurityAggregator.getCalculator(parentStats, 0)
  }

  /**
   * Update the stats for a given (feature, bin) for ordered features, using the given label.
   */
  def update(
      featureIndex: Int,
      binIndex: Int,
      label: Double,
      numSamples: Int,
      sampleWeight: Double): Unit = {
    val i = featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
  }

  /**
   * Update the parent node stats using the given label.
   */
  def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = {
    impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
  }

  /**
   * Faster version of [[update]].
   * Update the stats for a given (feature, bin), using the given label.
   *
   * @param featureOffset This is a pre-computed feature offset
   *                      from [[getFeatureOffset]].
   */
  def featureUpdate(
      featureOffset: Int,
      binIndex: Int,
      label: Double,
      numSamples: Int,
      sampleWeight: Double): Unit = {
    impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
      label, numSamples, sampleWeight)
  }

  /**
   * Pre-compute feature offset for use with [[featureUpdate]].
   * For ordered features only.
   */
  def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)

  /**
   * For a given feature, merge the stats for two bins.
   *
   * @param featureOffset This is a pre-computed feature offset
   *                      from [[getFeatureOffset]].
   * @param binIndex The other bin is merged into this bin.
   * @param otherBinIndex This bin is not modified.
   */
  def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
    impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
      featureOffset + otherBinIndex * statsSize)
  }

  /**
   * Merge this aggregator with another, and returns this aggregator.
   * This method modifies this aggregator in-place.
   *
   * Both size checks run before any element is touched, so a failed check leaves this
   * aggregator unmodified.
   *
   * @param other Aggregator whose stats are added into this one; it is not modified.
   * @return This aggregator, after accumulation.
   * @throws IllegalArgumentException if the stats vectors or the parent stats vectors of the
   *                                  two aggregators differ in length.
   */
  def merge(other: DTStatsAggregator): DTStatsAggregator = {
    // Validate everything up front: failing after the first loop would leave allStats merged
    // but parentStats not, i.e. a half-merged aggregator.
    require(allStatsSize == other.allStatsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
    require(statsSize == other.statsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
        s"stats vectors. This aggregator's parent stats are length $statsSize, " +
        s"but the other is ${other.statsSize}.")

    var i = 0
    // TODO: Test BLAS.axpy
    while (i < allStatsSize) {
      allStats(i) += other.allStats(i)
      i += 1
    }

    var j = 0
    while (j < statsSize) {
      parentStats(j) += other.parentStats(j)
      j += 1
    }

    this
  }
}
相关信息
相关文章
0
赞
- 所属分类: 前端技术
- 本文标签:
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
7、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦