// Apache Spark MLlib Strategy source.
// File path: mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.configuration
import scala.beans.BeanProperty
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
/**
* Stores all the configuration options for tree construction
* @param algo Learning goal. Supported:
* `org.apache.spark.mllib.tree.configuration.Algo.Classification`,
* `org.apache.spark.mllib.tree.configuration.Algo.Regression`
* @param impurity Criterion used for information gain calculation.
* Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
* [[org.apache.spark.mllib.tree.impurity.Entropy]].
* Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
* @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
* 1 internal node + 2 leaf nodes).
* @param numClasses Number of classes for classification.
* (Ignored for regression.)
* Default value is 2 (binary classification).
* @param maxBins Maximum number of bins used for discretizing continuous features and
* for choosing how to split on features at each node.
* More bins give higher granularity.
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
* `org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort`
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. An entry (n to k)
* indicates that feature n is categorical with k categories
* indexed from 0: {0, 1, ..., k-1}.
* @param minInstancesPerNode Minimum number of instances each child must have after split.
* Default value is 1. If a split cause left or right child
* to have less than minInstancesPerNode,
* this split will not be considered as a valid split.
* @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
* If a split has less information gain than minInfoGain,
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB. If too small, then 1 node will be split per iteration, and
* its aggregates may exceed this size.
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD of node Id cache for each row.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
* the checkpoint directory is not set in
* [[org.apache.spark.SparkContext]], this setting is ignored.
*/
@Since("1.0.0")
class Strategy @Since("1.3.0") (
    @Since("1.0.0") @BeanProperty var algo: Algo,
    @Since("1.0.0") @BeanProperty var impurity: Impurity,
    @Since("1.0.0") @BeanProperty var maxDepth: Int,
    @Since("1.2.0") @BeanProperty var numClasses: Int = 2,
    @Since("1.0.0") @BeanProperty var maxBins: Int = 32,
    @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
    @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
    @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
    @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
    @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
    @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
    @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
    @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
    @Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0,
    @BeanProperty private[spark] var bootstrap: Boolean = false) extends Serializable {

  /**
   * Whether this strategy describes a multiclass classification problem,
   * i.e. `algo` is Classification and `numClasses` is strictly greater than 2.
   */
  @Since("1.2.0")
  def isMulticlassClassification: Boolean = {
    algo == Classification && numClasses > 2
  }

  /**
   * Whether this strategy describes a multiclass classification problem
   * that also has at least one categorical feature.
   */
  @Since("1.2.0")
  def isMulticlassWithCategoricalFeatures: Boolean = {
    isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
  }

  // scalastyle:off argcount
  /**
   * Backwards compatible constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
   * (pre-3.0 signature without `minWeightFractionPerNode`, which defaults to 0.0).
   */
  @Since("1.0.0")
  def this(
      algo: Algo,
      impurity: Impurity,
      maxDepth: Int,
      numClasses: Int,
      maxBins: Int,
      quantileCalculationStrategy: QuantileStrategy,
      categoricalFeaturesInfo: Map[Int, Int],
      minInstancesPerNode: Int,
      minInfoGain: Double,
      maxMemoryInMB: Int,
      subsamplingRate: Double,
      useNodeIdCache: Boolean,
      checkpointInterval: Int) = {
    this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
      categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
      subsamplingRate, useNodeIdCache, checkpointInterval, 0.0)
  }
  // scalastyle:on argcount

  /**
   * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]:
   * takes the categorical-features map as a `java.util.Map` of boxed integers.
   */
  @Since("1.1.0")
  def this(
      algo: Algo,
      impurity: Impurity,
      maxDepth: Int,
      numClasses: Int,
      maxBins: Int,
      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) = {
    this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
      minWeightFractionPerNode = 0.0)
  }

  /**
   * Sets Algorithm using a String.
   * @param algo "Classification" or "Regression"
   * @throws IllegalArgumentException if the string matches neither algorithm name
   */
  @Since("1.2.0")
  def setAlgo(algo: String): Unit = algo match {
    case "Classification" => setAlgo(Classification)
    case "Regression" => setAlgo(Regression)
    // Explicit failure with a descriptive message instead of an opaque scala.MatchError.
    case _ =>
      throw new IllegalArgumentException(
        s"DecisionTree Strategy given invalid algo parameter: $algo." +
        s" Valid settings are: Classification, Regression.")
  }

  /**
   * Sets categoricalFeaturesInfo using a Java Map (boxed integers are converted
   * to an immutable Scala `Map[Int, Int]`).
   */
  @Since("1.2.0")
  def setCategoricalFeaturesInfo(
      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
    this.categoricalFeaturesInfo =
      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap
  }

  /**
   * Check validity of parameters.
   * Throws exception if invalid.
   */
  private[spark] def assertValid(): Unit = {
    algo match {
      case Classification =>
        require(numClasses >= 2,
          s"DecisionTree Strategy for Classification must have numClasses >= 2," +
          s" but numClasses = $numClasses.")
        require(Set(Gini, Entropy).contains(impurity),
          s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
          s" Valid settings: Gini, Entropy")
      case Regression =>
        require(impurity == Variance,
          s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
          s" Valid settings: Variance")
      case _ =>
        throw new IllegalArgumentException(
          s"DecisionTree Strategy given invalid algo parameter: $algo." +
          s" Valid settings are: Classification, Regression.")
    }
    require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." +
      s" Valid values are integers >= 0.")
    require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
      s" Valid values are integers >= 2.")
    require(minInstancesPerNode >= 1,
      s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
    require(maxMemoryInMB <= 10240,
      s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
    require(subsamplingRate > 0 && subsamplingRate <= 1,
      s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
      s"$subsamplingRate")
    // Validate the 3.0.0 parameter as well; a split fraction must lie in [0.0, 0.5)
    // or no split could ever satisfy it for both children.
    require(minWeightFractionPerNode >= 0.0 && minWeightFractionPerNode < 0.5,
      s"DecisionTree Strategy requires minWeightFractionPerNode in [0.0, 0.5), but was given " +
      s"$minWeightFractionPerNode")
  }

  /**
   * Returns a shallow copy of this instance.
   */
  @Since("1.2.0")
  def copy: Strategy = {
    new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
      quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
      minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
      checkpointInterval, minWeightFractionPerNode)
  }
}
@Since("1.2.0")
object Strategy {

  /**
   * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]].
   * @param algo "Classification" or "Regression"
   */
  @Since("1.2.0")
  def defaultStrategy(algo: String): Strategy =
    defaultStrategy(Algo.fromString(algo))

  /**
   * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]].
   * @param algo Algo.Classification or Algo.Regression
   */
  @Since("1.3.0")
  def defaultStrategy(algo: Algo): Strategy = {
    // Only impurity and numClasses depend on the learning goal; every other
    // knob keeps the primary constructor's default value.
    val (defaultImpurity, defaultNumClasses): (Impurity, Int) = algo match {
      case Algo.Classification => (Gini, 2)
      case Algo.Regression => (Variance, 0)
    }
    new Strategy(algo = algo, impurity = defaultImpurity, maxDepth = 10,
      numClasses = defaultNumClasses)
  }
}
// NOTE(review): trailing web-page footer (site navigation / ad links unrelated to
// this source file) removed so the file remains valid Scala.