
Spark MLlib - Decision Tree Source Code Analysis


Guiding Questions

1. What does train inside RandomForest (org.apache.spark.mllib.tree.RandomForest.scala) do?
2. What does DecisionTree.findSplitsBins do?

We start with decision trees: they are simple, widely used, and current boosting and random forest methods are typically built on top of them.
For the decision tree algorithm itself, see the earlier blog post; it is essentially a greedy algorithm, where each split is chosen so that the data become as ordered as possible.

So how do we define "ordered" versus "disordered"?
Disorder is measured by node impurity.

For classification, entropy or the Gini index measures the disorder of the information.
For regression, variance measures disorder: the larger the variance, the more the values differ from one another.
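
For reference, the standard definitions of these impurity measures (writing $p_i$ for the fraction of node samples in class $i$, and $y_1, \ldots, y_N$ for the node's labels with mean $\bar{y}$) are:

$$\mathrm{Entropy} = -\sum_{i=1}^{C} p_i \log_2 p_i \qquad \mathrm{Gini} = 1 - \sum_{i=1}^{C} p_i^2 \qquad \mathrm{Variance} = \frac{1}{N}\sum_{j=1}^{N} (y_j - \bar{y})^2$$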
Information gain
measures the drop in impurity when a parent node is split into child nodes, i.e. the gain in orderedness.
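Concretely, for a parent node with $N$ samples split into left and right children of sizes $N_L$ and $N_R$, the gain is the standard weighted impurity decrease:

$$IG = I(\mathrm{parent}) - \frac{N_L}{N}\, I(\mathrm{left}) - \frac{N_R}{N}\, I(\mathrm{right})$$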

An MLlib Decision Tree Example
Let's go straight to a regression example; the classification case is much the same:
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 100
val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on training instances and compute training error
val labelsAndPredictions = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("Training Mean Squared Error = " + trainMSE)
println("Learned regression tree model:\n" + model)



This is all fairly straightforward:
Since this is regression, impurity is defined as variance.
maxDepth is the maximum tree depth, set to 5 here.
maxBins is the maximum number of bins.
First, understand what a bin is. The decision tree algorithm works by repeatedly partitioning the values of each feature.
For a discrete feature this is simple: with m values there are at most 2^(m-1) - 1 partitions; if the values are ordered, there are at most m - 1 partitions.
For example, take an age feature with the 3 values old, middle, young. If unordered, there are 2^(3-1) - 1 = 3 partitions: old | middle, young; old, middle | young; old, young | middle (see the sketch below).
If ordered, i.e. respecting the order old, middle, young, there are only m - 1 = 2 partitions: old | middle, young; old, middle | young.
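
To make the unordered case concrete, here is a minimal, self-contained Scala sketch (my own illustration, not MLlib code; all names are made up) that enumerates the 2^(m-1) - 1 binary partitions for the age example. Fixing the last category on the right-hand side avoids counting each partition and its mirror image twice:

object UnorderedSplits {
  def main(args: Array[String]): Unit = {
    val categories = Array("old", "middle", "young")
    val m = categories.length
    val numSplits = (1 << (m - 1)) - 1 // 2^(3-1) - 1 = 3
    for (mask <- 1 to numSplits) {
      // bit i of mask sends category i to the left side; category m-1 always stays on the right
      val (left, right) =
        categories.zipWithIndex.partition { case (_, i) => ((mask >> i) & 1) == 1 }
      println(left.map(_._1).mkString(",") + " | " + right.map(_._1).mkString(","))
    }
  }
}

This prints the same 3 partitions listed above (up to which side is written first).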
For a continuous feature, partitioning means dividing its range; the cut points are the splits, and the intervals they delimit are the bins.
In principle a continuous feature has infinitely many candidate split points, but for efficiency we must pick a suitable subset.
A common approach is to use the values of that feature that actually occur in the training set as candidate split points,
but for distributed data, collecting and sorting all values is expensive, so sampling is used instead (a minimal sketch follows).
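
As a sketch of this sampling idea (again my own illustration, not the MLlib code, which we will read shortly): sample the feature's values, sort the sample, and take evenly spaced order statistics as the split candidates.

import scala.util.Random

// Pick (numBins - 1) approximate-quantile split candidates for one continuous feature.
def approxSplits(values: Seq[Double], numBins: Int, sampleSize: Int): Seq[Double] = {
  val sample = Random.shuffle(values).take(sampleSize).sorted
  val stride = sample.length.toDouble / numBins // samples per bin
  (1 until numBins).map { i =>
    // evenly spaced order statistics of the sorted sample approximate the quantiles
    sample(math.min((i * stride).toInt, sample.length - 1))
  }
}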

Source Code Analysis
The entry point is DecisionTree.trainRegressor, which is akin to calling a static function (it is defined on object DecisionTree):
org.apache.spark.mllib.tree.DecisionTree.scala
/**
 * Method to train a decision tree model for regression.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 *              Labels are real numbers.
 * @param categoricalFeaturesInfo Map storing arity of categorical features.
 *                                E.g., an entry (n -> k) indicates that feature n is categorical
 *                                with k categories indexed from 0: {0, 1, ..., k-1}.
 * @param impurity Criterion used for information gain calculation.
 *                 Supported values: "variance".
 * @param maxDepth Maximum depth of the tree.
 *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
 *                 (suggested value: 5)
 * @param maxBins maximum number of bins used for splitting features
 *                (suggested value: 32)
 * @return DecisionTreeModel that can be used for prediction
 */
def trainRegressor(
    input: RDD[LabeledPoint],
    categoricalFeaturesInfo: Map[Int, Int],
    impurity: String,
    maxDepth: Int,
    maxBins: Int): DecisionTreeModel = {
  val impurityType = Impurities.fromString(impurity)
  train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
}



This calls the static function train:
def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClassesForClassification: Int,
    maxBins: Int,
    quantileCalculationStrategy: QuantileStrategy,
    categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
  val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
    quantileCalculationStrategy, categoricalFeaturesInfo)
  new DecisionTree(strategy).train(input)
}



All the parameters are packed into a Strategy object; a DecisionTree instance is then created and its member function train is called:
/**
 * :: Experimental ::
 * A class which implements a decision tree learning algorithm for classification and regression.
 * It supports both continuous and categorical features.
 * @param strategy The configuration parameters for the tree algorithm which specify the type
 *                 of algorithm (classification, regression, etc.), feature type (continuous,
 *                 categorical), depth of the tree, quantile calculation strategy, etc.
 */
@Experimental
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
  strategy.assertValid()
  /**
   * Method to train a decision tree model over an RDD
   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
   * @return DecisionTreeModel that can be used for prediction
   */
  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
    // Note: random seed will not be used since numTrees = 1.
    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
    val rfModel = rf.train(input)
    rfModel.trees(0)
  }
}



As we can see, DecisionTree is designed as a special case of RandomForest: a RandomForest with a single tree.
So RandomForest.train() is called, and since there is only one tree, trees(0) is taken at the end.

org.apache.spark.mllib.tree.RandomForest.scala
Let's focus on what train inside RandomForest does:
/**
 * Method to train a decision tree model over an RDD
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @return RandomForestModel that can be used for prediction
 */
def train(input: RDD[LabeledPoint]): RandomForestModel = {
  // 1. metadata
  val retaggedInput = input.retag(classOf[LabeledPoint])
  val metadata =
    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
  // 2. Find the splits and the corresponding bins (interval between the splits) using a sample
  // of the input data.
  val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
  // 3. Bin feature values (TreePoint representation).
  // Cache input RDD for speedup during multiple passes.
  val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
  val baggedInput = if (numTrees > 1) {
    BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
  } else {
    BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
  }.persist(StorageLevel.MEMORY_AND_DISK)
  // set maxDepth and compute memory usage
  // depth of the decision tree
  // Max memory usage for aggregates
  // TODO: Calculate memory usage more precisely.
  // ........
  /*
   * The main idea here is to perform group-wise training of the decision tree nodes thus
   * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
   * Each data sample is handled by a particular node (or it reaches a leaf and is not used
   * in lower levels).
   */
  // FIFO queue of nodes to train: (treeIndex, node)
  val nodeQueue = new mutable.Queue[(Int, Node)]()
  val rng = new scala.util.Random()
  rng.setSeed(seed)
  // Allocate and queue root nodes.
  val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
  Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
  while (nodeQueue.nonEmpty) {
    // Collect some nodes to split, and choose features for each node (if subsampling).
    // Each group of nodes may come from one or multiple trees, and at multiple levels.
    val (nodesForGroup, treeToNodeToIndexInfo) =
      RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) // trivial for a single decision tree: nodeQueue holds only one node, nothing to select
    // 4. Choose node splits, and enqueue new nodes as needed.
    DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
  }
  val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
  RandomForestModel.build(trees)
}



1. DecisionTreeMetadata.buildMetadata
org.apache.spark.mllib.tree.impl.DecisionTreeMetadata.scala
This builds some metadata needed later on.
The crucial part is computing the number of bins and splits for each feature.
Computing the number of bins:
// the number of bins cannot exceed the number of training examples
val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
// set the default value
val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
if (numClasses > 2) {
  // Multiclass classification
  val maxCategoriesForUnorderedFeature =
    ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
  strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
    // Decide if some categorical features should be treated as unordered features,
    //  which require 2 * ((1 << numCategories - 1) - 1) bins.
    // We do this check with log values to prevent overflows in case numCategories is large.
    // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
    if (numCategories <= maxCategoriesForUnorderedFeature) {
      unorderedFeatures.add(featureIndex)
      numBins(featureIndex) = numUnorderedBins(numCategories)
    } else {
      numBins(featureIndex) = numCategories
    }
  }
} else {
  // Binary classification or regression
  strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
    numBins(featureIndex) = numCategories
  }
}



In the other cases, the number of bins simply equals the feature's numCategories.
The unordered case is special:
/**
 * Given the arity of a categorical feature (arity = number of categories),
 * return the number of bins for the feature if it is to be treated as an unordered feature.
 * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
 * there are math.pow(2, arity - 1) - 1 such splits.
 * Each split has 2 corresponding bins.
 */
def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
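A quick sanity check of this arithmetic: for the age feature above with arity = 3, numUnorderedBins(3) = 2 * ((1 << 2) - 1) = 2 * 3 = 6 bins, i.e. the 3 subset splits enumerated earlier, each contributing a left bin and a right bin.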


Given the number of bins, the number of splits follows:
/**
 * Number of splits for the given feature.
 * For unordered features, there are 2 bins per split.
 * For ordered features, there is 1 more bin than split.
 */
def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
  numBins(featureIndex) >> 1
} else {
  numBins(featureIndex) - 1
}



2. DecisionTree.findSplitsBins
First, find the candidate splits on each feature and the corresponding bins; everything that follows builds on this.
The comments here explain the split/bin counting we saw above:
a. For continuous features, splits = numBins - 1 per feature. As noted above, the candidate splits of a continuous feature are in principle unbounded; to pick numBins - 1 of them, sampling is used.
b. For categorical features, there are two cases:
    b.1. Unordered features, used for multiclass classification with low-arity features (few categories). Here there are many possible partitions, so subsets of categories are used as splits.
    b.2. Ordered features, used for regression, binary classification, or multiclass classification with high-arity features. Here there are fewer possible partitions (m - 1), so each category serves as a split point.
/**
 * Returns splits and bins for decision tree calculation.
 * Continuous and categorical features are handled differently.
 *
 * Continuous features:
 *   For each feature, there are numBins - 1 possible splits representing the possible binary
 *   decisions at each node in the tree.
 *   This finds locations (feature values) for splits using a subsample of the data.
 *
 * Categorical features:
 *   For each feature, there is 1 bin per split.
 *   Splits and bins are handled in 2 ways:
 *   (a) "unordered features"
 *       For multiclass classification with a low-arity feature
 *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
 *       the feature is split based on subsets of categories.
 *   (b) "ordered features"
 *       For regression and binary classification,
 *       and for multiclass classification with a high-arity feature,
 *       there is one bin per category.
 *
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @param metadata Learning and dataset metadata
 * @return A tuple of (splits, bins).
 *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
 *          of size (numFeatures, numSplits).
 *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
 *          of size (numFeatures, numBins).
 */
protected[tree] def findSplitsBins(
    input: RDD[LabeledPoint],
    metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
  val numFeatures = metadata.numFeatures
  // Sample the input only if there are continuous features.
  val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
  val sampledInput = if (hasContinuousFeatures) { // continuous features take many values, so sample
    // Calculate the number of samples for approximate quantile calculation.
    val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) // far more samples than bins
    val fraction = if (requiredSamples < metadata.numExamples) { // set the sampling fraction
      requiredSamples.toDouble / metadata.numExamples
    } else {
      1.0
    }
    input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
  } else {
    new Array[LabeledPoint](0)
  }
  metadata.quantileStrategy match {
    case Sort =>
      val splits = new Array[Array[Split]](numFeatures) // initialize splits and bins
      val bins = new Array[Array[Bin]](numFeatures)
      // Find all splits.
      // Iterate over all features.
      var featureIndex = 0
      while (featureIndex < numFeatures) {
        val numSplits = metadata.numSplits(featureIndex) // numbers of splits and bins computed earlier
        val numBins = metadata.numBins(featureIndex)
        if (metadata.isContinuous(featureIndex)) { // continuous feature
          val numSamples = sampledInput.length
          splits(featureIndex) = new Array[Split](numSplits)
          bins(featureIndex) = new Array[Bin](numBins)
          val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted // all sampled values of this feature, sorted
          val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) // samples per bin: the stride between split points
          logDebug("stride = " + stride)
          for (splitIndex <- 0 until numSplits) { // place the splits
            val sampleIndex = splitIndex * stride.toInt // splitIndex * stride gives the index of the sample used for this split
            // Set threshold halfway in between 2 samples.
            val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 // threshold is the midpoint of two adjacent samples
            splits(featureIndex)(splitIndex) =
              new Split(featureIndex, threshold, Continuous, List()) // create the Split object
          }
          bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), // the first bin starts at a DummyLowSplit whose value is Double.MinValue
            splits(featureIndex)(0), Continuous, Double.MinValue)
          for (splitIndex <- 1 until numSplits) { // create all the middle bins
            bins(featureIndex)(splitIndex) =
              new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
                Continuous, Double.MinValue)
          }
          bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), // the last bin ends at a DummyHighSplit whose value is Double.MaxValue
            new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
        } else { // categorical feature
          // Categorical feature
          val featureArity = metadata.featureArity(featureIndex) // number of categories of this feature
          if (metadata.isUnordered(featureIndex)) { // unordered categorical feature
            // TODO: The second half of the bins are unused.  Actually, we could just use
            //       splits and not build bins for unordered features.  That should be part of
            //       a later PR since it will require changing other code (using splits instead
            //       of bins in a few places).
            // Unordered features
            //   2^(maxFeatureValue - 1) - 1 combinations
            splits(featureIndex) = new Array[Split](numSplits)
            bins(featureIndex) = new Array[Bin](numBins)
            var splitIndex = 0
            while (splitIndex < numSplits) {
              val categories: List[Double] =
                extractMultiClassCategories(splitIndex + 1, featureArity)
              splits(featureIndex)(splitIndex) =
                new Split(featureIndex, Double.MinValue, Categorical, categories)
              bins(featureIndex)(splitIndex) = {
                if (splitIndex == 0) {
                  new Bin(
                    new DummyCategoricalSplit(featureIndex, Categorical),
                    splits(featureIndex)(0),
                    Categorical,
                    Double.MinValue)
                } else {
                  new Bin(
                    splits(featureIndex)(splitIndex - 1),
                    splits(featureIndex)(splitIndex),
                    Categorical,
                    Double.MinValue)
                }
              }
              splitIndex += 1
            }
          } else { // ordered categorical feature: nothing to precompute, since bins simply correspond to the categories
            // Ordered features
            //   Bins correspond to feature values, so we do not need to compute splits or bins
            //   beforehand.  Splits are constructed as needed during training.
            splits(featureIndex) = new Array[Split](0)
            bins(featureIndex) = new Array[Bin](0)
          }
        }
        featureIndex += 1
      }
      (splits, bins)
    case MinMax =>
      throw new UnsupportedOperationException("minmax not supported yet.")
    case ApproxHist =>
      throw new UnsupportedOperationException("approximate histogram not supported yet.")
  }
}



3. TreePoint and BaggedPoint
TreePoint is the internal counterpart of LabeledPoint, so a conversion is needed:
private def labeledPointToTreePoint(
    labeledPoint: LabeledPoint,
    bins: Array[Array[Bin]],
    featureArity: Array[Int],
    isUnordered: Array[Boolean]): TreePoint = {
  val numFeatures = labeledPoint.features.size
  val arr = new Array[Int](numFeatures)
  var featureIndex = 0
  while (featureIndex < numFeatures) {
    arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
      isUnordered(featureIndex), bins)
    featureIndex += 1
  }
  new TreePoint(labeledPoint.label, arr) // simply replaces the labeledPoint's values with arr
}


arr holds the results of findBin.
The interesting work is for continuous features: each continuous value is converted, via binary search, into the index of the bin containing it (a sketch of this lookup follows).
For categorical data, the bin is simply featureValue.toInt.
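
Here is a hedged sketch of that lookup (assumed behavior, not the exact findBin code): given a feature's sorted split thresholds, a lower-bound binary search yields the bin index.

// Locate the bin of a continuous value; bin i covers (thresholds(i-1), thresholds(i)].
def binIndex(value: Double, thresholds: Array[Double]): Int = {
  var lo = 0
  var hi = thresholds.length
  while (lo < hi) {
    val mid = (lo + hi) >>> 1
    if (value <= thresholds(mid)) hi = mid else lo = mid + 1
  }
  lo // index of the first threshold >= value; thresholds.length means the last, open-ended bin
}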
As for BaggedPoint: random forest is a classic bagging algorithm, so the training set needs bootstrap sampling.
A decision tree, however, is the special single-tree random forest, so no sampling is needed:
BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
is really just a thin wrapper.

4. DecisionTree.findBestSplits
This code is rather convoluted, especially where it is interleaved with the random forest logic.
In short, the key part is:
// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
  binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats, predict))
}.collectAsMap()



Let's look at the implementation of binsToBestSplit; to keep things clear, we only follow the continuous-feature path.
Four parameters:
binAggregates: DTStatsAggregator, i.e. the ImpurityAggregator, which supplies the logic for computing impurity
splits: Array[Array[Split]], the candidate splits for each feature
featuresForNode: Option[Array[Int]], the features assigned to this tree node
node: Node, which tree node this is
Return value:
(Split, InformationGainStats, Predict), where
Split is the best split object (it carries the feature index and the split index),
InformationGainStats is the gain object for that split, indicating how much gain it yields, i.e. how far it lowers impurity,
Predict is the node's prediction value; for a continuous label this is the mean, as analyzed below.
private def binsToBestSplit(
    binAggregates: DTStatsAggregator,
    splits: Array[Array[Split]],
    featuresForNode: Option[Array[Int]],
    node: Node): (Split, InformationGainStats, Predict) = {
  // For each (feature, split), calculate the gain, and select the best (feature, split).
  val (bestSplit, bestSplitStats) =
    Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => // iterate over each feature
      // ...... fetch the splits for this feature
      // Find best split.
      val (bestFeatureSplitIndex, bestFeatureGainStats) =
        Range(0, numSplits).map { case splitIdx => // iterate over each split
          val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
          val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
          rightChildStats.subtract(leftChildStats)
          predictWithImpurity = Some(predictWithImpurity.getOrElse(
            calculatePredictImpurity(leftChildStats, rightChildStats)))
          val gainStats = calculateGainForSplit(leftChildStats, // compute the gain, an InformationGainStats object
            rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
          (splitIdx, gainStats)
        }.maxBy(_._2.gain) // pick the split index with the largest gain
      (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
    }
    // ...... categorical-feature cases omitted
  }.maxBy(_._2.gain) // pick the feature whose best split has the largest gain
  (bestSplit, bestSplitStats, predictWithImpurity.get._1)
}


Predict deserves a closer look.
predictWithImpurity.get._1 is the first element of the predictWithImpurity tuple,
i.e. the predict part of calculatePredictImpurity's return value:
private def calculatePredictImpurity(
    leftImpurityCalculator: ImpurityCalculator,
    rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
  val parentNodeAgg = leftImpurityCalculator.copy
  parentNodeAgg.add(rightImpurityCalculator)
  val predict = calculatePredict(parentNodeAgg)
  val impurity = parentNodeAgg.calculate()
  (predict, impurity)
}


private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
  val predict = impurityCalculator.predict
  val prob = impurityCalculator.prob(predict)
  new Predict(predict, prob)
}


So how do predict and impurity differ here? As we can see:
impurity = ImpurityCalculator.calculate()
predict = ImpurityCalculator.predict
For continuous features, let's look at the Variance implementation:
/**
 * Calculate the impurity from the stored sufficient statistics.
 */
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))


@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
  if (count == 0) {
    return 0
  }
  val squaredLoss = sumSquares - (sum * sum) / count
  squaredLoss / count
}





From the implementation of calculate we can see that the impurity is the variance, i.e. the mean squared deviation from the mean.
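
To connect the code to the usual formula: with count $= N$, sum $= \sum_j y_j$ and sumSquares $= \sum_j y_j^2$, the code computes

$$\frac{\mathtt{sumSquares} - \mathtt{sum}^2 / \mathtt{count}}{\mathtt{count}} = \frac{1}{N}\sum_{j=1}^{N} y_j^2 - \bar{y}^2 = \frac{1}{N}\sum_{j=1}^{N} (y_j - \bar{y})^2$$

which is exactly the variance impurity defined at the beginning.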
/**
 * Prediction which should be made based on the sufficient statistics.
 */
def predict: Double = if (count == 0) {
  0
} else {
  stats(1) / count
}



while predict is simply the mean.




