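Implementing Custom Functions in Spark

Scenario: when Spark's built-in functions cannot meet a development requirement, you can implement the logic as a custom function (a UDF or a UDAF) and register it for use in SQL.

Example:
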
package spark

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

object TestSparkUdf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("student")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val rdd2 = spark.sparkContext.makeRDD(Array(Student2(18, "one"), Student2(20, "two")))
    // registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView replaces it
    rdd2.toDF().createOrReplaceTempView("student")

    spark.udf.register("myupper", myUpper _)
    val df = spark.sql("select myupper(name) from student")
    df.show()
//    +-----------------+
//    |UDF:myupper(name)|
//    +-----------------+
//    |              ONE|
//    |              TWO|
//    +-----------------+
    spark.udf.register("myavg", new myAvg())
    val df2 = spark.sql("select myavg(age) from student")
    df2.show()
//    +----------+
//    |myavg(age)|
//    +----------+
//    |        19|
//    +----------+
    spark.stop()

  }

  // UDF: one-to-one (each input value produces exactly one output value)
  def myUpper(str: String): String = str.toUpperCase()

}

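// UDAF: many-to-one (all rows of a group are reduced to a single value).
// Note: UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of Aggregator.
class myAvg extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = StructType(Array(StructField("age", LongType)))
  // Schema of the aggregation buffer: running total and row count
  override def bufferSchema: StructType = StructType(Array(StructField("total", LongType), StructField("count", LongType)))
  // Data type of the function's result
  override def dataType: DataType = LongType
  // Whether the function is deterministic (the same input always gives the same output)
  override def deterministic: Boolean = true
  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }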
  // Update the buffer when a new input row arrives; null ages are skipped
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      buffer.update(1, buffer.getLong(1) + 1)
    }
  }
  // Merge two partial buffers (for example, from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }
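  // Compute the final result. Long division truncates toward zero;
  // return null for empty input to avoid dividing by zero.
  override def evaluate(buffer: Row): Any = {
    if (buffer.getLong(1) == 0L) null else buffer.getLong(0) / buffer.getLong(1)
  }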
}

case class Student2(age: Long, name: String)
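
To see why the aggregate prints 19 for the ages 18 and 20, the buffer lifecycle (initialize, update, merge, evaluate) can be sketched in plain Scala without Spark. This is only an illustration: `AvgBuffer` and `AvgSketch` are hypothetical names and not part of the Spark API, the split into two partitions is an assumption, and returning `None` for an empty group is a choice made for this sketch.

```scala
// Hypothetical stand-in for the (total, count) aggregation buffer; not a Spark class.
final case class AvgBuffer(total: Long, count: Long)

object AvgSketch {
  // Mirrors initialize: start from an empty buffer.
  val zero: AvgBuffer = AvgBuffer(0L, 0L)

  // Mirrors update: fold one input value into a partition-local buffer.
  def update(b: AvgBuffer, age: Long): AvgBuffer =
    AvgBuffer(b.total + age, b.count + 1)

  // Mirrors merge: combine two partial buffers.
  def merge(a: AvgBuffer, b: AvgBuffer): AvgBuffer =
    AvgBuffer(a.total + b.total, a.count + b.count)

  // Mirrors evaluate: Long division, with None for an empty group (sketch assumption).
  def evaluate(b: AvgBuffer): Option[Long] =
    if (b.count == 0L) None else Some(b.total / b.count)

  def main(args: Array[String]): Unit = {
    // Assumed layout: each student lands in its own partition.
    val partitions = Seq(Seq(18L), Seq(20L))
    val partials = partitions.map(_.foldLeft(zero)(update))
    val result = evaluate(partials.foldLeft(zero)(merge))
    println(result) // Some(19)
  }
}
```

Because the division operates on Long values, the result is truncated: ages 18 and 21 would also yield 19. If a fractional average is needed, the result type would have to be a floating-point type instead of LongType.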
