nettman 发表于 2020-8-17 20:35:01

Spark SQL 项目:实现各区域热门商品前N统计

一. 需求

1.1 需求简介
这里的热门商品是从点击量的维度来看的.

计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。




1.2 思路分析
使用 sql 来完成. 碰到复杂的需求, 可以使用 udf 或 udaf

查询出来所有的点击记录, 并与 city_info 表连接, 得到每个城市所在的地区. 与 Product_info 表连接得到产品名称
按照地区和商品 id 分组, 统计出每个商品在每个地区的总点击次数
每个地区内按照点击次数降序排列
只取前三名. 并把结果保存在数据库中
城市备注需要自定义 UDAF 函数



二. 实际操作
1. 准备数据
  我们这次 Spark-sql 操作中所有的数据均来自 Hive.

  首先在 Hive 中创建表, 并导入数据.

  一共有 3 张表: 1 张用户行为表, 1 张城市表, 1 张产品表

1. 打开Hive




2. 创建三个表
CREATE TABLE `user_visit_action`(
`date` string,
`user_id` bigint,
`session_id` string,
`page_id` bigint,
`action_time` string,
`search_keyword` string,
`click_category_id` bigint,
`click_product_id` bigint,
`order_category_ids` string,
`order_product_ids` string,
`pay_category_ids` string,
`pay_product_ids` string,
`city_id` bigint)
row format delimited fields terminated by '\t';

CREATE TABLE `product_info`(
`product_id` bigint,
`product_name` string,
`extend_info` string)
row format delimited fields terminated by '\t';

CREATE TABLE `city_info`(
`city_id` bigint,
`city_name` string,
`area` string)
row format delimited fields terminated by '\t';





3. 上传数据

load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;
load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;
load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info;




4. 测试是否上传成功

hive> select * from city_info;




2. 显示各区域热门商品 Top3


// user_visit_actionproduct_infocity_info

1. 先把需要的字段查出来   t1
select
    ci.*,
    pi.product_name,
    click_product_id
from user_visit_action uva
join product_info pi on uva.click_product_id=pi.product_id
join city_info ci on uva.city_id=ci.city_id

2. 按照地区和商品名称聚合
select
    area,
    product_name,
    count(*)count
from t1
group by area , product_name

3. 按照地区进行分组开窗 排序 开窗函数 t3 // (rank(1 2 2 4 5...) row_number(1 2 3 4...) dense_rank(1 2 2 3 4...))
select
    area,
    product_name,
    count,
    rank() over(partition by area order by count desc)
fromt2


4. 过滤出来名次小于等于3的
select
    area,
    product_name,
    count
fromt3
where rk <=3


2. 运行结果




3. 定义udaf函数 得到需求结果

package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
**
*
* @author 不温卜火
*         *
* @create 2020-08-06 13:24
**
*         MyCSDN :https://buwenbuhuo.blog.csdn.net/
*
*/
class CityRemarkUDAF extends UserDefinedAggregateFunction {
// 输入数据的类型:北京String
override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
}

// 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000Map,总的点击量1000/?
override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
}

// 输出的数据类型"北京21.2%,天津13.2%,其他65.6%"String
override def dataType: DataType = StringType

// 相同的输入是否应用有相同的输出.
override def deterministic: Boolean = true

// 给存储数据初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map()
    // 初始化总的点击量
    buffer(1) = 0L
}

// 分区内合并 Map[城市名, 点击量]
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
      // 1. 总的点击量 + 1
      buffer(1) = buffer.getLong(1) + 1L
      // 2. 给这个城市的点击量 +1 =>   找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去
      val map: collection.Map = buffer.getMap(0)
      buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
}

// 分区间的合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs](0)
    val map2 = buffer2.getAs](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. 总数的聚合
    buffer1(1) = total1 + total2
    // 2. map的聚合
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
      map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }

}

// 最终的输出结果
override def evaluate(buffer: Row): Any = {
    // "北京21.2%,天津13.2%,其他65.6%"
    val cityAndCount: collection.Map = buffer.getMap(0)
    val total: Long = buffer.getLong(1)

    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
//    CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio))
    cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
}
}

case class CityRemark(cityName: String, cityRatio: Double) {
val formatter = new DecimalFormat("0.00%")

override def toString: String = s"$cityName:${formatter.format(cityRatio)}"

}



运行结果




4 .保存到Mysql

1. 源码

    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
      |select
      |    area,
      |    product_name,
      |    count,
      |    remark
      |from t3
      |where rk<=3
      |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)


2.运行结果





三. 完整代码

1. udaf

package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
**
*
* @author 不温卜火
*         *
* @create 2020-08-06 13:24
**
*         MyCSDN :https://buwenbuhuo.blog.csdn.net/
*
*/
class CityRemarkUDAF extends UserDefinedAggregateFunction {
// 输入数据的类型:北京String
override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
}

// 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000Map,总的点击量1000/?
override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
}

// 输出的数据类型"北京21.2%,天津13.2%,其他65.6%"String
override def dataType: DataType = StringType

// 相同的输入是否应用有相同的输出.
override def deterministic: Boolean = true

// 给存储数据初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map()
    // 初始化总的点击量
    buffer(1) = 0L
}

// 分区内合并 Map[城市名, 点击量]
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
      // 1. 总的点击量 + 1
      buffer(1) = buffer.getLong(1) + 1L
      // 2. 给这个城市的点击量 +1 =>   找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去
      val map: collection.Map = buffer.getMap(0)
      buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
}

// 分区间的合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs](0)
    val map2 = buffer2.getAs](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. 总数的聚合
    buffer1(1) = total1 + total2
    // 2. map的聚合
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
      map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }

}

// 最终的输出结果
override def evaluate(buffer: Row): Any = {
    // "北京21.2%,天津13.2%,其他65.6%"
    val cityAndCount: collection.Map = buffer.getMap(0)
    val total: Long = buffer.getLong(1)

    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
//    CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio))
    cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
}
}

case class CityRemark(cityName: String, cityRatio: Double) {
val formatter = new DecimalFormat("0.00%")

override def toString: String = s"$cityName:${formatter.format(cityRatio)}"

}



2. 主程序(具体实现)


package com.buwenbuhuo.spark.sql.project

import java.util.Properties

import org.apache.spark.sql.SparkSession

/**
**
*
* @author 不温卜火
*         *
* @create 2020-08-05 19:01
**
*         MyCSDN :https://buwenbuhuo.blog.csdn.net/
*
*/
object SqlApp {
def main(args: Array): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .master("local")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    spark.udf.register("remark",new CityRemarkUDAF)

    // 去执行sql,从hive查询数据
    spark.sql("use spark0806")
    spark.sql(
      """
      |select
      |    ci.*,
      |    pi.product_name,
      |    uva.click_product_id
      |from user_visit_action uva
      |join product_info pi on uva.click_product_id=pi.product_id
      |join city_info ci on uva.city_id=ci.city_id
      |
      |""".stripMargin).createOrReplaceTempView("t1")

    spark.sql(
      """
      |select
      |    area,
      |    product_name,
      |    count(*) count,
      |    remark(city_name) remark
      |from t1
      |group by area, product_name
      |""".stripMargin).createOrReplaceTempView("t2")

    spark.sql(
      """
      |select
      |    area,
      |    product_name,
      |    count,
      |    remark,
      |    rank() over(partition by area order by count desc) rk
      |from t2
      |""".stripMargin).createOrReplaceTempView("t3")

    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
      |select
      |    area,
      |    product_name,
      |    count,
      |    remark
      |from t3
      |where rk<=3
      |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)


    // 把结果写入到mysql中

    spark.close()
}
}





原文链接:
https://blog.csdn.net/qq_16146103/article/details/107824095

作者:不温卜火

页: [1]
查看完整版本: Spark SQL 项目:实现各区域热门商品前N统计