1. Overview
Simply put, this operator keeps a running, cumulative state for the stream.
The updateStateByKey operation lets us maintain arbitrary state and keep updating it with new information.
Whenever new data arrives or is updated, the user can hold onto whatever state is needed. Using this feature takes two steps (see the sketch after the two steps):
Define the state: it can be of any data type.
Define the state update function: a function that specifies how to update the state from the previous state and the new values in the input stream.
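A minimal, self-contained sketch of these two steps (a running word count over a socket source; all names and the checkpoint path are only illustrative):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object UpdateStateByKeyDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("updateStateByKeyDemo").setMaster("local[*]")
    val ssc = new StreamingContext(conf, Seconds(5))
    // updateStateByKey requires checkpointing to be enabled
    ssc.checkpoint("./demo-checkpoint")

    // Step 1, the state: a running count of type Long per word
    // Step 2, the update function: previous state plus the new values of the current batch
    val updateFunc = (newValues: Seq[Long], state: Option[Long]) =>
      Some(state.getOrElse(0L) + newValues.sum)

    val wordCounts = ssc.socketTextStream("localhost", 9999)
      .flatMap(_.split(" "))
      .map(word => (word, 1L))
      .updateStateByKey[Long](updateFunc)

    wordCounts.print()
    ssc.start()
    ssc.awaitTermination()
  }
}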
For stateful operations, the RDDs of the current and all historical time slices have to be accumulated, so as time passes the volume of data involved in the computation keeps growing;
if the data volume is very large, or the performance requirements are extremely strict, consider keeping the accumulated data in Redis or Tachyon instead, as sketched below.
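One possible shape of that alternative, only as a sketch (it assumes a Redis instance on localhost:6379 and the Jedis client on the classpath; the hash and key names are made up): instead of accumulating state with updateStateByKey, each batch's counts are pushed into Redis, and Redis holds the cumulative totals.

import org.apache.spark.streaming.dstream.DStream
import redis.clients.jedis.Jedis

object RedisStateSink {
  // Flush the per-key counts of each batch into a Redis hash; HINCRBY adds them to
  // whatever is already stored, so the cumulative state lives in Redis, not in Spark.
  def saveCountsToRedis(counts: DStream[(String, Long)]): Unit = {
    counts.foreachRDD { rdd =>
      rdd.foreachPartition { items =>
        val jedis = new Jedis("localhost", 6379) // assumed Redis endpoint
        for ((key, count) <- items) {
          jedis.hincrBy("ad:click:stat", key, count) // hash/field names are illustrative
        }
        jedis.close()
      }
    }
  }
}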
Note: the updateStateByKey operation requires the checkpoint mechanism to be enabled.
2. Example Code
Note: when I stop the program and then restart it, the records from before the restart are lost and the previous state is not accumulated; a sketch of how to recover the state from the checkpoint is given at the end of this section.
import java.util.Date
import commons.conf.ConfigurationManager
import commons.constant.Constants
import commons.utils.DateUtils
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.kafka010.{
ConsumerStrategies, KafkaUtils, LocationStrategies}
import org.apache.spark.streaming._
import scala.collection.mutable.ArrayBuffer
object AdverStat {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setAppName("adver").setMaster("local[*]")
val sparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
// StreamingContext.getActiveOrCreate(checkpointDir, func)
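// Micro-batch interval of 5 seconds: each RDD in the DStreams below covers 5 seconds of data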
val streamingContext = new StreamingContext(sparkSession.sparkContext, Seconds(5))
val kafka_broker = ConfigurationManager.config.getString(Constants.KAFKA_BROKERS)
val kafka_topics = ConfigurationManager.config.getString(Constants.KAFKA_TOPICS)
val kafkaParam = Map(
"bootstrap.servers" -> kafka_broker,
"key.deserializer" -> classOf[StringDeserializer],
"value.deserializer" -> classOf[StringDeserializer],
"group.id" -> "group1",
/*
auto.offset.reset can be set to:
latest: use the committed offset if one exists; otherwise start consuming from the newest data
earliest: use the committed offset if one exists; otherwise start consuming from the beginning
none: use the committed offset if one exists; otherwise throw an error
*/
"auto.offset.reset" -> "latest",
"enable.auto.commit" -> (false: java.lang.Boolean)
)
// The DStream consists of RDDs; each RDD contains messages, and each message is a (key, value) pair
val adRealTimeDStream = KafkaUtils.createDirectStream[String, String](streamingContext,
LocationStrategies.PreferConsistent,
ConsumerStrategies.Subscribe[String, String](Array(kafka_topics), kafkaParam)
)
// Take the value of every message in the DStream
// String:timestamp province city userid adid
val adRealTimeValueDStream = adRealTimeDStream.map(item => item.value())
val adRealTimeFilterDStream = adRealTimeValueDStream.transform {
logRDD =>
//blackListArray:Array[AdBlacklist] AdBlacklist:userId
val blacklistArray = AdBlacklistDAO.findAll()
// Extract the userIds on the blacklist
val userIdArray = blacklistArray.map(item => item.userid)
logRDD.filter {
// log:timestamp province city userid adid
case log =>
val logSplit = log.split(" ")
val userId = logSplit(3).toLong
!userIdArray.contains(userId)
}
}
// Requirement 2 uses updateStateByKey, so a checkpoint directory must be specified
streamingContext.checkpoint("./spark-streaming")
// The checkpoint interval must be a multiple of the batch interval used when the StreamingContext was created
adRealTimeFilterDStream.checkpoint(Duration(10000))
// Requirement 1: maintain the user blacklist in real time
generateBlackList(adRealTimeFilterDStream)
// Requirement 2: per-day click counts of each ad by province and city
provinceCityClickStat(adRealTimeFilterDStream)
streamingContext.start()
streamingContext.awaitTermination()
}
def provinceCityClickStat(adRealTimeFilterDStream: DStream[String]) = {
val key2ProvinceCityDStream = adRealTimeFilterDStream.map {
case log =>
val logSplit = log.split(" ")
val timeStamp = logSplit(0).toLong
val dateKey = DateUtils.formatDateKey(new Date(timeStamp))
val province = logSplit(1)
val city = logSplit(2)
val adid = logSplit(4)
val key = dateKey + "_" + province + "_" + city + "_" + adid
(key, 1L)
}
val key2StateDStream = key2ProvinceCityDStream.updateStateByKey[Long] {
// the state update function
(values: Seq[Long], state: Option[Long]) =>
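// values: all new values for this key in the current batch; state: the count accumulated so far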
var newValue = 0L
if (state.isDefined)
newValue = state.get
for (value <- values) {
newValue += value
}
Some(newValue) // return the updated state
}
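// Persist the accumulated counts: for every batch, write each partition to the database in one batch update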
key2StateDStream.foreachRDD {
rdd =>
rdd.foreachPartition {
items =>
val adStatArray = new ArrayBuffer[AdStat]()
// key:date_province_city_adid
for ((key, count) <- items) {
val keySplit = key.split("_")
val date = keySplit(0)
val province = keySplit(1)
val city = keySplit(2)
val adid = keySplit(3).toLong
adStatArray += AdStat(date, province, city, adid, count)
}
AdStatDAO.updateBatch(adStatArray.toArray)
}
}
}
def generateBlackList(adRealTimeFilterDStream: DStream[String]) = {
// key2NumDStream: [RDD[(key, 1L)]]
val key2NumDStream = adRealTimeFilterDStream.map {
// log:timestamp province city userid adid
case log =>
val logSplit = log.split(" ")
val timestamp = logSplit(0).toLong
//yy-mm-dd
val dateKey = DateUtils.formatDateKey(new Date(timestamp))
val userId = logSplit(3).toLong
val adid = logSplit(4).toLong
val key = dateKey + "_" + userId + "_" + adid
(key, 1L)
}
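// Aggregate within the current batch: clicks per (date, user, ad) key in this interval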
val key2CountDstream = key2NumDStream.reduceByKey(_ + _)
// Use the data in each RDD to update the per-user click count table
key2CountDstream.foreachRDD {
rdd =>
rdd.foreachPartition {
items =>
val clickCountArray = new ArrayBuffer[AdUserClickCount]()
for ((key, count) <- items) {
val keySplit = key.split("_")
val date = keySplit(0)
val userId = keySplit(1).toLong
val adid = keySplit(2).toLong
clickCountArray += AdUserClickCount(date, userId, adid, count)
}
AdUserClickCountDAO.updateBatch(clickCountArray.toArray)
}
}
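// Keep only the keys whose cumulative click count, read back from the click-count table, exceeds 100; those users become blacklist candidates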
val key2BlackListDStream = key2CountDstream.filter {
case (key, value) =>
val keySplit = key.split("_")
val date = keySplit(0)
val userId = keySplit(1).toLong
val adid = keySplit(2).toLong
val clickCount = AdUserClickCountDAO.findClickCountByMultiKey(date, userId, adid)
clickCount > 100
}
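// Take the userId out of each remaining key and de-duplicate, so each user is written to the blacklist only once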
val userIdDStream = key2BlackListDStream.map {
case (key, count) => key.split("_")(1).toLong
}.transform(rdd => rdd.distinct())
userIdDStream.foreachRDD {
rdd =>
rdd.foreachPartition {
items =>
val userIdArray = new ArrayBuffer[AdBlacklist]()
for (userId <- items) {
userIdArray += AdBlacklist(userId)
}
AdBlacklistDAO.insertBatch(userIdArray.toArray)
}
}
}
}
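As noted above, the accumulated state is lost on a restart because the StreamingContext is built from scratch. For the state to survive a restart, the context has to be recovered from the checkpoint directory, which is what the commented-out StreamingContext.getActiveOrCreate(checkpointDir, func) hint refers to. A minimal sketch of that pattern (paths, host and port are illustrative; a trivial running count stands in for the real pipeline):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object RecoverableStreamingApp {
  val checkpointDir = "./spark-streaming" // illustrative checkpoint path

  // The whole DStream graph (input, updateStateByKey, output) must be defined inside
  // the creating function, so that a restart can rebuild it from the checkpoint.
  def createContext(): StreamingContext = {
    val conf = new SparkConf().setAppName("adver").setMaster("local[*]")
    val ssc = new StreamingContext(conf, Seconds(5))
    ssc.checkpoint(checkpointDir)
    val counts = ssc.socketTextStream("localhost", 9999)
      .map(line => (line, 1L))
      .updateStateByKey[Long]((values: Seq[Long], state: Option[Long]) =>
        Some(state.getOrElse(0L) + values.sum))
    counts.print()
    ssc
  }

  def main(args: Array[String]): Unit = {
    // First run: createContext() is called. After a restart: the context, including
    // the keyed state, is reloaded from checkpointDir instead of being recreated.
    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
    ssc.start()
    ssc.awaitTermination()
  }
}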