分享

Apache Spark源码走读之23 -- Spark MLLib中拟牛顿法L-BFGS的源码实现

本帖最后由 pig2 于 2015-1-6 14:18 编辑
问题导读
1.牛顿法有哪些优点体现?
2.L-BFGS算法中使用到的正则化方法是什么?















概要
本文就拟牛顿法L-BFGS的由来做一个简要的回顾,然后就其在spark mllib中的实现进行源码走读。

拟牛顿法数学原理



251942413766872.png

251942527663054.png

251943054239018.png

251943165489184.png
251943296574923.png
251943430956462.png
251943584703860.png
251944100325624.png
251944255486064.png

代码实现
L-BFGS算法中使用到的正则化方法是SquaredL2Updater。
算法实现上使用到了由scalanlp的成员项目breeze库中的BreezeLBFGS函数,mllib中自定义了BreezeLBFGS所需要的DiffFunctions.
251945451886496.png

runLBFGS函数的源码实现如下
  1. def
  2. runLBFGS(
  3.       data: RDD[(Double, Vector)],
  4.       gradient: Gradient,
  5.       updater: Updater,
  6.       numCorrections: Int,
  7.       convergenceTol: Double,
  8.       maxNumIterations: Int,
  9.       regParam: Double,
  10.       initialWeights: Vector): (Vector, Array[Double]) = {
  11.    
  12. val lossHistory = new ArrayBuffer[Double](maxNumIterations)
  13.     val numExamples = data.count()
  14.     val costFun =
  15.       new CostFun(data, gradient, updater, regParam, numExamples)
  16.     val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
  17.     val states =
  18.       lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
  19.     /**
  20.      * NOTE: lossSum and loss is computed using the weights from the previous iteration
  21.      * and regVal is the regularization value computed in the previous iteration as well.
  22.      */
  23.     var state = states.next()
  24.     while(states.hasNext) {
  25.       lossHistory.append(state.value)
  26.       state = states.next()
  27.     }
  28.     lossHistory.append(state.value)
  29.     val weights = Vectors.fromBreeze(state.x)
  30.     logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
  31.       lossHistory.takeRight(10).mkString(", ")))
  32.     (weights, lossHistory.toArray)
  33.   }
复制代码

costFun函数是算法实现中的重点
  1. private
  2. class CostFun(
  3.     data: RDD[(Double, Vector)],
  4.     gradient: Gradient,
  5.     updater: Updater,
  6.     regParam: Double,
  7.     numExamples: Long) extends DiffFunction[BDV[Double]] {
  8.     private var i = 0
  9.     override def calculate(weights: BDV[Double]) = {
  10.       // Have a local copy to avoid the serialization of CostFun object which is not serializable.
  11.       val localData = data
  12.       val localGradient = gradient
  13.       val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
  14.           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
  15.             val l = localGradient.compute(
  16.               features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
  17.             (grad, loss + l)
  18.           },
  19.           combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
  20.             (grad1 += grad2, loss1 + loss2)
  21.           })
  22.       /**
  23.        * regVal is sum of weight squares if it's L2 updater;
  24.        * for other updater, the same logic is followed.
  25.        */
  26.       val regVal = updater.compute(
  27.         Vectors.fromBreeze(weights),
  28.         Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
  29.       val loss = lossSum / numExamples + regVal
  30.       /**
  31.        * It will return the gradient part of regularization using updater.
  32.        *
  33.        * Given the input parameters, the updater basically does the following,
  34.        *
  35.        * w' = w - thisIterStepSize * (gradient + regGradient(w))
  36.        * Note that regGradient is function of w
  37.        *
  38.        * If we set gradient = 0, thisIterStepSize = 1, then
  39.        *
  40.        * regGradient(w) = w - w'
  41.        *
  42.        * TODO: We need to clean it up by separating the logic of regularization out
  43.        *       from updater to regularizer.
  44.        */
  45.       // The following gradientTotal is actually the regularization part of gradient.
  46.       // Will add the gradientSum computed from the data with weights in the next step.
  47.       val gradientTotal = weights - updater.compute(
  48.         Vectors.fromBreeze(weights),
  49.         Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze
  50.       // gradientTotal = gradientSum / numExamples + gradientTotal
  51.       axpy(1.0 / numExamples, gradientSum, gradientTotal)
  52.       i += 1
  53.       (loss, gradientTotal)
  54.     }
  55.   }
  56. }
复制代码





相关内容


Apache Spark源码走读之1 -- Spark论文阅读笔记

Apache Spark源码走读之2 -- Job的提交与运行

Apache Spark源码走读之3-- Task运行期之函数调用关系分析

Apache Spark源码走读之4 -- DStream实时流数据处理

Apache Spark源码走读之5-- DStream处理的容错性分析

Apache Spark源码走读之6-- 存储子系统分析

Apache Spark源码走读之7 -- Standalone部署方式分析

Apache Spark源码走读之8 -- Spark on Yarn

Apache Spark源码走读之9 -- Spark源码编译

Apache Spark源码走读之10 -- 在YARN上运行SparkPi

Apache Spark源码走读之11 -- sql的解析与执行

Apache Spark源码走读之12 -- Hive on Spark运行环境搭建

Apache Spark源码走读之13 -- hiveql on spark实现详解

Apache Spark源码走读之14 -- Graphx实现剖析

Apache Spark源码走读之15 -- Standalone部署模式下的容错性分析

Apache Spark源码走读之16 -- spark repl实现详解

Apache Spark源码走读之17 -- 如何进行代码跟读

Apache Spark源码走读之18 -- 使用Intellij idea调试Spark源码

Apache Spark源码走读之19 -- standalone cluster模式下资源的申请与释放

Apache Spark源码走读之20 -- ShuffleMapTask计算结果的保存与读取

Apache Spark源码走读之21 -- WEB UI和Metrics初始化及数据更新过程分析

Apache Spark源码走读之22 -- 浅谈mllib中线性回归的算法实现

Apache Spark源码走读之23 -- Spark MLLib中拟牛顿法L-BFGS的源码实现

Apache Spark源码走读之24 -- Sort-based Shuffle的设计与实现





本文转自徽沪一郎http://www.cnblogs.com/hseagle/p/3927887.html
欢迎加入about云群90371779322273151432264021 ,云计算爱好者群,亦可关注about云腾讯认证空间||关注本站微信

已有(1)人评论

跳转到指定楼层
355815741 发表于 2015-1-5 09:50:14
学习了,谢谢lz分享~
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

推荐上一条 /2 下一条