博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Spark MLlib 之 aggregate和treeAggregate从原理到应用
阅读量:6450 次
发布时间:2019-06-23

本文共 5176 字,大约阅读时间需要 17 分钟。

在阅读spark mllib源码的时候,发现一个出镜率很高的函数——aggregate和treeAggregate,比如matrix.columnSimilarities()中。为了好好理解这两个方法的使用,于是整理了本篇内容。

由于treeAggregate是在aggregate基础上的优化版本,因此先来看看aggregate是什么.

更多内容参考

aggregate

先直接看一下代码例子:

import org.apache.spark.sql.SparkSessionobject AggregateTest {  def main(args: Array[String]): Unit = {    val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()    spark.sparkContext.setLogLevel("WARN")    // 创建rdd,并分成6个分区    val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)    // 输出每个分区的内容    rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{      Array((s" $index : ${it.toList.mkString(",")}")).toIterator    }).foreach(println)    // 执行agg    val res1 = rdd.aggregate(0)(seqOp, combOp)  }  // 分区内执行的方法,直接加和  def seqOp(s1:Int, s2:Int):Int = {    println("seq: "+s1+":"+s2)    s1 + s2  }  // 在driver端汇总  def combOp(c1: Int, c2: Int): Int = {    println("comb: "+c1+":"+c2)    c1 + c2  }}

这段代码的主要目的就是为了求和。考虑到spark分区并行计算的特性,在每个分区独立加和,最后再汇总加和。

过程可以参考下面的图片:

o_agg.jpg

首先看一下map阶段,即在每个分区内计算加和。初始情况如蓝色方块所示,内容为:

分区号:里面的内容如,0分区内的数据为6和8

当执行seqop时,会说先用初始值0开始遍历累加,原理类似如下:

rdd.mapPartitions((it:Iterator)=>{    var sum = init_value // 默认为0    it.foreach(sum + _)    sum})

因此屏幕上会出现下面的内容,由于分区之间是并行的,所以最后的结果是乱序的:

seq: 0:6seq: 0:1seq: 0:3seq: 1:9seq: 3:10seq: 0:2seq: 0:5seq: 5:7seq: 12:12seq: 0:4seq: 4:11seq: 6:8

计算完成后,依次遍历每个分区结果,进行累加:

comb: 0:10comb: 10:13comb: 23:2comb: 25:24comb: 49:15comb: 64:14

aggregate的源码也比较简单:

def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {    var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())    val cleanSeqOp = sc.clean(seqOp)    val cleanCombOp = sc.clean(combOp)    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)    sc.runJob(this, aggregatePartition, mergeResult)    jobResult  }

treeAggregate

treeAggregate在aggregate的基础上做了一些优化,因为aggregate是在每个分区计算完成后,把所有的数据拉倒driver端,进行统一的遍历合并,这样如果数据量很大,在driver端可能会OOM。

因此treeAggregate在中间多加了一层合并。

先来看看代码,没有任何的变化:

import org.apache.spark.sql.SparkSessionobject TreeAggregateTest {  def main(args: Array[String]): Unit = {    val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()    spark.sparkContext.setLogLevel("WARN")    val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)    rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{      Array(s" $index : ${it.toList.mkString(",")}").toIterator    }).foreach(println)    val res1 = rdd.treeAggregate(0)(seqOp, combOp)    println(res1)  }  def seqOp(s1:Int, s2:Int):Int = {    println("seq: "+s1+":"+s2)    s1 + s2  }  def combOp(c1: Int, c2: Int): Int = {    println("comb: "+c1+":"+c2)    c1 + c2  }}

输出的结果则发生了变化,首先分区内的操作不变:

3 : 3,10 2 : 2 0 : 6,8 1 : 1,9 4 : 4,11 5 : 5,7,12seq: 0:3seq: 0:6seq: 3:10seq: 6:8seq: 0:2seq: 0:1seq: 1:9seq: 0:4seq: 4:11seq: 0:5seq: 5:7seq: 12:12...

在合并的时候发生了 变化:

comb: 10:13comb: 23:24comb: 14:2comb: 16:15comb: 47:31

配合下面的流程图,可以更好的理解:

o_treeAgg.jpg
搭配treeAggregate的源码来看一下:

def treeAggregate[U: ClassTag](zeroValue: U)(      seqOp: (U, T) => U,      combOp: (U, U) => U,      depth: Int = 2): U = withScope {    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")    if (partitions.length == 0) {      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())    } else {      // 这里都没什么变化,在分区中遍历数据累加      val cleanSeqOp = context.clean(seqOp)      val cleanCombOp = context.clean(combOp)      val aggregatePartition =        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)      var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))      // 关键是这下面的内容 !!!!      // 首先获得当前的分区数      var numPartitions = partiallyAggregated.partitions.length      // 计算合适的并行度,我这里相当于6^(1/2),也就是2.4左右,ceill向上取整后变成3.      // max(3,2)得到最后的结果为3。即每个树的分枝有3个叶子节点      val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)            // 遍历分区,通过对scale取模进行合并计算      // 这里判断一下,当前的分区数是否还够分。如果少于条件值 scale+(p/scale),就停止分区      while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {        numPartitions /= scale        val curNumPartitions = numPartitions        // 重新定义分区id,并按照分区id重新分区,执行合并计算        partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {          (i, iter) => iter.map((i % curNumPartitions, _))        }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values      }      // 最后统计结果      partiallyAggregated.reduce(cleanCombOp)    }  }

spark中的应用

// matrix求相似度def columnSimilarities(threshold: Double): CoordinateMatrix = {...              columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)}// 统计每一个向量的相关数据,里面包含了min max 等等很多信息def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {  val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(    (aggregator, data) => aggregator.add(data),    (aggregator1, aggregator2) => aggregator1.merge(aggregator2))  updateNumRows(summary.count)  summary}

了解了treeAggregate之后,后续就可以看matrix的并行求解相似度的源码了!敬请期待吧...

参考

转载地址:http://rflwo.baihongyu.com/

你可能感兴趣的文章
Linux安装与基础命令
查看>>
quick UIInput 的使用
查看>>
远程管理
查看>>
免费(学习)使用的数据库
查看>>
锁定10月10日,九州云Animbus7.0与你不见不散
查看>>
事务 C#中TransactionScope的使用方法和原理
查看>>
如何在NGINX中重定向一个网址(301 跳转)
查看>>
CentOS 6.4 关闭 selinux
查看>>
Android数据存储方式
查看>>
使用 Fonty Python 管理你的字体
查看>>
VMware Workstation 14 Pro黑屏
查看>>
《Effective Java》2nd 笔记
查看>>
The connection to adb is down, and a severe error has occured.
查看>>
什么是重构,什么不是重构
查看>>
我的友情链接
查看>>
Python中用matplotlib.pyplot画图总结
查看>>
图解css3:核心技术与案例实战
查看>>
NLB+DFS实现高可靠性WEB服务器
查看>>
Java8中Lambda表达式的10个例子
查看>>
sublime text 2 实现java编辑器环境
查看>>