
Using UDF, UDAF, and UDTF in SparkSQL


· 3 min read

1. UDF

A UDF (User Defined Function) is a user-defined function that is applied row by row.

1.1 Fixed parameters

import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()

    val df = spark.createDataFrame(List(("北京市", "北京市", "昌平区"), ("贵州省", "贵阳市", "南明区")))
    df.toDF("province", "city", "district").createOrReplaceTempView("countries")

    // A UDF with a fixed number of parameters: join three columns with a separator
    val customUdf = (sep: String, province: String, city: String, district: String) => {
      province + sep + city + sep + district
    }

    // Register the function under a name that SQL statements can call
    spark.udf.register("customUdf", customUdf)
    spark.sql("select customUdf('-', province, city, district) as address from countries").show()
  }
}
+-----------------+
|    address      |
+-----------------+
|北京市-北京市-昌平区|
|贵州省-贵阳市-南明区|
+-----------------+
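
The registered UDF can also be invoked from the DataFrame DSL via callUDF in org.apache.spark.sql.functions (renamed call_udf in newer Spark releases); a minimal sketch, reusing the countries view from above:

import org.apache.spark.sql.functions.{callUDF, lit}
import spark.implicits._

// Same UDF, called through the DSL instead of a SQL string
spark.table("countries")
  .select(callUDF("customUdf", lit("-"), $"province", $"city", $"district").alias("address"))
  .show()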

1.2 Variable-length parameters

import org.apache.spark.sql.functions.{array, lit, udf}
import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    val df = spark.createDataFrame(List(("北京市", "北京市", "昌平区"), ("贵州省", "贵阳市", "南明区")))
    val customDf = df.toDF("province", "city", "district")

    // Emulate variable-length parameters: pack the columns into a single
    // array column, which the function receives as a Seq
    val customUdf = (sep: String, columns: Seq[Any]) => {
      columns.mkString(sep)
    }

    // Plain SQL: build the array with the built-in array() function
    customDf.createOrReplaceTempView("countries")
    spark.udf.register("customUdf", customUdf)
    spark.sql("select customUdf('-', array(province, city, district)) as address from countries").show()

    // DataFrame DSL: wrap the same function with udf()
    import spark.implicits._
    val cols = array($"province", $"city", $"district")
    val sep = lit("-")
    val customConcat = udf(customUdf)
    customDf.select(customConcat(sep, cols).alias("address")).show()
  }
}
+-----------------+
|    address      |
+-----------------+
|北京市-北京市-昌平区|
|贵州省-贵阳市-南明区|
+-----------------+

+-----------------+
|    address      |
+-----------------+
|北京市-北京市-昌平区|
|贵州省-贵阳市-南明区|
+-----------------+
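
Because the columns are packed into a single array, the same function handles any number of columns; for instance, a hypothetical two-column call:

// Same UDF, different arity: only province and city this time
customDf.select(customConcat(lit("/"), array($"province", $"city")).alias("short_address")).show()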

1.3 When the SQL needs data from the driver

You can combine a broadcast variable with a UDF, as in the example below. Broadcasting ships the driver-side Map to each executor once, instead of serializing it into every task closure.

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    spark.createDataFrame(
      List(("北京市", "北京市", "昌平区"), ("贵州省", "贵阳市", "南明区"), ("加州", "上海市", "海淀区"))
    ).toDF("province", "city", "district").createOrReplaceTempView("countries")

    // Driver-side lookup table, shipped to every executor once via broadcast
    val targets = Map("北京市" -> "beijing", "贵州省" -> "guizhou")
    val targetsRef: Broadcast[Map[String, String]] = spark.sparkContext.broadcast(targets)

    // The UDF reads the broadcast value on the executor side
    spark.udf.register("customConcat", (province: String, city: String, district: String) => {
      val provinceMap = targetsRef.value
      provinceMap.getOrElse(province, "unknown") + "-" + city + "-" + district
    })
    
    spark.sql("select customConcat(province, city, district) as address from countries").show()
  }
}
+-------------------+
|        address    |
+-------------------+
|beijing-北京市-昌平区|
|guizhou-贵阳市-南明区|
|unknown-上海市-海淀区|
+-------------------+

2. UDAF

A UDAF (User Defined Aggregate Function) is a user-defined aggregate function. How does an aggregate function differ from an ordinary function? An ordinary function takes one row of input and produces one output, whereas an aggregate function takes a group of inputs (usually many rows) and produces a single output, i.e., it reduces a group of values to one.

The steps for using UserDefinedAggregateFunction are:

1. Define a class that extends UserDefinedAggregateFunction and implement the method for each stage

2. Register the UDAF with Spark, binding it to a name

3. Call it in SQL statements using the name bound above

  • Define the class first
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

object CustomAvg extends UserDefinedAggregateFunction {
  // Schema of the input column(s)
  override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)

  // Schema of the intermediate buffer: running sum and count
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Type of the final result
  override def dataType: DataType = DoubleType

  // The same input always yields the same output
  override def deterministic: Boolean = true

  // Zero out the buffer before aggregation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer, skipping nulls
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Combine two partial buffers (e.g. from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result from the buffer
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}
  • Register and use
import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    spark.createDataFrame(List((100, "a"), (120, "b"), (110, "c"), (115,"d")))
      .toDF("age", "name").createOrReplaceTempView("age_table")
    spark.udf.register("customAvgUdf", CustomAvg)
    spark.sql("select customAvgUdf(age) as age_avg from age_table").show()
  }
}
  • Query result
+-------+
|age_avg|
+-------+
| 111.25|
+-------+

3. UDTF

A UDTF (User Defined Table-Generating Function) takes one row of input and produces zero or more rows of output.

To be completed~
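
In the meantime, note that Spark SQL has no direct Scala API for writing a UDTF the way Hive does; a common workaround for one-row-in, many-rows-out behavior is a UDF that returns an array, combined with the built-in explode function. A minimal sketch under that approach (splitAddress is an illustrative name):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, udf}

object UdtfLikeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    import spark.implicits._

    val df = Seq("北京市-北京市-昌平区", "贵州省-贵阳市-南明区").toDF("address")

    // One row in, several rows out: return a Seq from the UDF, then explode it
    val splitAddress = udf((address: String) => address.split("-").toSeq)
    df.select(explode(splitAddress($"address")).alias("part")).show()
  }
}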
