SparkSql 实现两表查询

项目需求:

ip.txt:包含ip起始地址,ip结束地址,ip所属省份

access.txt:包含ip地址和各种访问数据

需求:两表联合查询每个省份的ip数量

一、将两张表的数据提取出来,转换成DataFrame,创建两个view。实现join查询

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Dataset, SparkSession}

object IPDemo {
  Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
  val ipFile = ("d:\\data\\spark\\ip.txt")
  val acessFile = "d:\\data\\spark\\access.log"
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Ip").master("local[*]").getOrCreate()
    import spark.implicits._
    //读取ip文件
    val ipFile = spark.read.textFile("d:\\data\\spark\\ip.txt")
    //整理ip文件
    val ipRules: Dataset[(Long, Long, String)] = ipFile.map(line => {
      val splited = line.split("[|]")
      val startNum = splited(2).toLong
      val endNum = splited(3).toLong
      val province = splited(6)
      (startNum,endNum,province)
    })
    //加入元数据
    val ipDF = ipRules.toDF("start_num","end_num","province")
    //将ip注册成view
    ipDF.createTempView("t_ip")
    //读取访问日志文件
    val access_file = spark.read.textFile(acessFile)
    import day07.MyUtils
    val accessDF = access_file.map(line =>{
      val fields = line.split("[|]")
      val ip = fields(1)
      MyUtils.ip2Long(ip)
    }).toDF("ip")
    //将访问日志整理成视图
    accessDF.createTempView("t_access")
    //sql语句 关联两张表
    val result = spark.sql("SELECT province,count(*) counts FROM t_ip JOIN t_access ON ip>=start_num and ip<=end_num GROUP BY province ORDER BY counts DESC")
    result.show();
    spark.stop()

  }

}

二、改进方法

两表join,如果数据量太大,就会导致运行速度变慢。所以将ip的数据以广播的方式发送到Executor。构建一个自定义方法,进行查询。

import day07.MyUtils
import org.apache.spark.sql.{Dataset, SparkSession}

object IpLocation {
  val ipFile = "d:\\data\\spark\\ip.txt"
  val acessFile = "d:\\data\\spark\\access.log"
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SQLIPLocation").master("local[*]").getOrCreate()
    //隐式转换
    import  spark.implicits._
    //读取ip文件
    val ipFile = spark.read.textFile("d:\\data\\spark\\ip.txt")
    //整理ip文件
    val ipRules: Dataset[(Long, Long, String)] = ipFile.map(line => {
      val splited = line.split("[|]")
      val startNum = splited(2).toLong
      val endNum = splited(3).toLong
      val province = splited(6)
      (startNum,endNum,province)
    })
    //加入元数据
    //val ipDF = ipRules.toDF("start_num","end_num","province")

    //将全部的IP规则收集到Driver端
    val ipRulesDriver = ipRules.collect()
    //广播 阻塞的方法 没有广播完,就不会向下
    val broadcastRef = spark.sparkContext.broadcast(ipRulesDriver)

    //读取web日志
    val accessLogLines = spark.read.textFile(acessFile)
    val ips = accessLogLines.map(line => {
      val Fields = line.split("[|]")
      val ip = Fields(1)
      MyUtils.ip2Long(ip)
    }).toDF("ip_num")
    //将访问日志数据注册成视图
    ips.createTempView("access_ip")

    //定义并注册自定义函数
    //自定义函数在哪里定义的?  (Driver)  业务逻辑在Executor执行
    spark.udf.register("ip_num2Province",(ip_num:Long)=>{
      //获取广播到Driver
      //根据Driver端的广播变量引用,在发送task时,会将Driver端的引用伴随着发送到Executor
      val rulesExecute: Array[(Long, Long, String)] = broadcastRef.value
      val index = MyUtils.binarySearch(rulesExecute,ip_num)
      var province = "未知"
      if(index != -1){
        province = rulesExecute(index)._3
      }
      province
    })

    val result = spark.sql("select ip_num2Province(ip_num) province,count(*) counts from access_ip group by province order by counts desc")

    result.show()

    spark.stop()

  }

}

三、用到的工具包代码如下:

import java.sql.{Connection, DriverManager, PreparedStatement}

/**
  * Created by zx on 2017/12/12.
  */
object MyUtils {

//将ip转换成数字类型
  def ip2Long(ip:String):Long ={
    val fragments = ip.split("[.]")
    var ipNum =0L
    for(i<- 0 until fragments.length){
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }



//查找某个ip所属的省份
  def binarySearch(lines: Array[(Long,Long,String)],ip: Long):Int ={
    var low =0
    var high =lines.length-1
    while(low <=high){
      val middle =(low+high)/2
      if((ip>=lines(middle)._1) && (ip<=lines(middle)._2))
        return middle
      if(ip < lines(middle)._1)
        high=middle -1
      else{
        low =middle +1
      }
    }
    -1
  }

//连接mysql 插入数据
  def data2MySQL(iter:Iterator[(String,Int)])={
    val conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/test","root","123456")
    val ps = conn.prepareStatement("insert into access_log values (?,?)")
    iter.foreach(x =>{
      ps.setString(1,x._1)
      ps.setInt(2,x._2)
      ps.executeUpdate()
    })
    if(conn!=null){
      conn.close()
    }
    if(ps!=null){
      ps.close()
    }
  }

}

猜你喜欢

转载自blog.csdn.net/qq_32539825/article/details/82906911