Using UDF, UDAF, and UDTF in Spark SQL
1. UDF
A UDF (User Defined Function) is a user-defined scalar function: it takes the column values of a single row and produces a single output value.
1.1 Fixed parameters
import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    val df = spark.createDataFrame(List(("北京市", "北京市", "昌平区"), ("贵州省", "贵阳市", "南明区")))
    df.toDF("province", "city", "district").createOrReplaceTempView("countries")
    // Join the three columns with the given separator
    val customUdf = (sep: String, province: String, city: String, district: String) => {
      province + sep + city + sep + district
    }
    spark.udf.register("customUdf", customUdf)
    spark.sql("select customUdf('-', province, city, district) as address from countries").show()
  }
}
+-----------------+
| address |
+-----------------+
|北京市-北京市-昌平区|
|贵州省-贵阳市-南明区|
+-----------------+
1.2 Variadic parameters
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, lit, udf}

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    val df = spark.createDataFrame(List(("北京市", "北京市", "昌平区"), ("贵州省", "贵阳市", "南明区")))
    val customDf = df.toDF("province", "city", "district")
    // Accept any number of columns packed into an array column; use a concrete
    // element type (Seq[String]) so Spark can map it to a SQL type reliably
    val customUdf = (sep: String, columns: Seq[String]) => {
      columns.mkString(sep)
    }
    // Plain SQL: wrap the columns with array(...) in the query
    customDf.createOrReplaceTempView("countries")
    spark.udf.register("customUdf", customUdf)
    spark.sql("select customUdf('-', array(province, city, district)) as address from countries").show()
    // DataFrame DSL: build the array column and the literal separator explicitly
    import spark.implicits._
    val cols = array($"province", $"city", $"district")
    val sep = lit("-")
    val customConcat = udf(customUdf)
    customDf.select(customConcat(sep, cols).alias("address")).show()
  }
}
+-----------------+
| address |
+-----------------+
|北京市-北京市-昌平区|
|贵州省-贵阳市-南明区|
+-----------------+
+-----------------+
| address |
+-----------------+
|北京市-北京市-昌平区|
|贵州省-贵阳市-南明区|
+-----------------+
1.3 When the SQL needs data from the driver
Combine a broadcast variable with a UDF, as in the example below:
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    spark.createDataFrame(
      List(("北京市", "北京市", "昌平区"), ("贵州省", "贵阳市", "南明区"), ("加州", "上海市", "海淀区"))
    ).toDF("province", "city", "district").createOrReplaceTempView("countries")
    // Driver-side lookup table, shipped to the executors once as a broadcast variable
    val targets = Map("北京市" -> "beijing", "贵州省" -> "guizhou")
    val targetsRef: Broadcast[Map[String, String]] = spark.sparkContext.broadcast(targets)
    spark.udf.register("customConcat", (province: String, city: String, district: String) => {
      val provinceMap = targetsRef.value
      provinceMap.getOrElse(province, "unknown") + "-" + city + "-" + district
    })
    spark.sql("select customConcat(province, city, district) as address from countries").show()
  }
}
+-------------------+
| address |
+-------------------+
|beijing-北京市-昌平区|
|guizhou-贵阳市-南明区|
|unknown-上海市-海淀区|
+-------------------+
2. UDAF
A UDAF (User Defined Aggregate Function) is a user-defined aggregate function. How does it differ from an ordinary function? An ordinary function takes one row of input and produces one output, while an aggregate function takes a group of inputs (usually many rows) and reduces them to a single output value.
Using UserDefinedAggregateFunction takes three steps:
1. Define a class that extends UserDefinedAggregateFunction and implement each of its methods
2. Register the UDAF with Spark, binding it to a name
3. Call it from SQL statements by that name
- Define the class first
object CustomAvg extends UserDefinedAggregateFunction {
  // Input: a single Long column
  override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)
  // Intermediate buffer: running sum and row count
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  // Type of the final result
  override def dataType: DataType = DoubleType
  // The same input always produces the same output
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // Fold one input row into the buffer, skipping nulls
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // Combine two partial buffers, e.g. from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Produce the final average
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}
- Register and use it
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    spark.createDataFrame(List((100, "a"), (120, "b"), (110, "c"), (115, "d")))
      .toDF("age", "name").createOrReplaceTempView("age_table")
    spark.udf.register("customAvgUdf", CustomAvg)
    spark.sql("select customAvgUdf(age) as age_avg from age_table").show()
  }
}
- Query result
+-------+
|age_avg|
+-------+
| 111.25|
+-------+
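Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of the typed Aggregator API, which can be registered for SQL via functions.udaf. The same average can be sketched with an Aggregator as follows (a sketch assuming Spark 3.x; the names AvgBuffer and CustomAvgAggregator are illustrative, not from the text above):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Running (sum, count) pair carried between rows and across partitions
case class AvgBuffer(sum: Long, count: Long)

object CustomAvgAggregator extends Aggregator[Long, AvgBuffer, Double] {
  // Empty buffer, analogous to initialize() above
  def zero: AvgBuffer = AvgBuffer(0L, 0L)
  // Fold one input value in, analogous to update()
  def reduce(b: AvgBuffer, a: Long): AvgBuffer = AvgBuffer(b.sum + a, b.count + 1)
  // Combine two partial buffers, analogous to merge()
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
  // Final result, analogous to evaluate()
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object AggregatorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    spark.createDataFrame(List((100, "a"), (120, "b"), (110, "c"), (115, "d")))
      .toDF("age", "name").createOrReplaceTempView("age_table")
    // udaf(...) wraps the typed Aggregator for use from SQL
    spark.udf.register("customAvgUdf", udaf(CustomAvgAggregator))
    spark.sql("select customAvgUdf(age) as age_avg from age_table").show()
  }
}
```

Besides avoiding the deprecated API, the Aggregator version is type-safe: the buffer is a case class rather than an untyped Row, so there are no getLong(0)-style positional accesses to get wrong.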
3. UDTF
A UDTF (User Defined Table-Generating Function) takes one row of input and produces zero or more output rows.
To be completed.
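While this section is still a stub, here is a hedged sketch of one common way to get table-generating behavior in Spark SQL without a dedicated UDTF interface: a UDF that returns a Seq, expanded into rows with explode(). (This is an illustrative substitute, not the full UDTF content this section will eventually cover; the column names and the ";"-separated sample data are invented for the example.)

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, udf}

object UdtfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("myApp").getOrCreate()
    import spark.implicits._
    // One row per province, with several districts packed into one string
    val df = List(("北京市", "昌平区;海淀区"), ("贵州省", "南明区")).toDF("province", "districts")
    // UDF returns a Seq ...
    val splitUdf = udf((s: String) => s.split(";").toSeq)
    // ... and explode() turns each element into its own row: one row in, many rows out
    df.select($"province", explode(splitUdf($"districts")).alias("district")).show()
  }
}
```

For true Hive-style UDTFs (e.g. a class extending GenericUDTF), Spark can also call them through HiveQL when Hive support is enabled, but the explode pattern above covers most one-to-many cases with plain Spark SQL.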