博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
SparkSQL之UDAF使用
阅读量:5833 次
发布时间:2019-06-18

本文共 2985 字,大约阅读时间需要 9 分钟。

1.创建一个类继承UserDefinedAggregateFunction类。

---------------------------------------------------------------------

package cn.piesat.test import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StructType} class CountUDAF extends UserDefinedAggregateFunction{
/** * 聚合函数的输入类型 * @return */ override def inputSchema: StructType = {
new StructType().add("ageType",IntegerType) } /** * 缓存的数据类型 * @return */ override def bufferSchema: StructType = {
new StructType().add("bufferAgeType",IntegerType) } /** * UDAF返回值的类型 * @return */ override def dataType: DataType = {
DataTypes.StringType } /** * 如果该函数是确定性的,那么将会返回true,一般给true就行。 * @return */ override def deterministic: Boolean = true /** * 为每个分组的数据执行初始化操作 * @param buffer */ override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0)=0 } /** * 更新操作,指的是每个分组有新的值进来的时候,如何进行分组对应的聚合值的计算 * @param buffer * @param input */ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val num= input.getAs[Int](0) buffer(0)=buffer.getAs[Int](0)+num } /** * 分区合并时执行的操作 * @param buffer1 * @param buffer2 */ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0)=buffer1.getAs[Int](0)+buffer2.getAs[Int](0) } /** * 最后返回的结果 * @param buffer * @return */ override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0).toString } } -------------------------------------------------------------- 2.在main函数中使用样例 ---------------------------------------------------------------
package cn.piesat.test import org.apache.spark.sql.SparkSession import scala.collection.mutable.ArrayBuffer object SparkSQLTest {
def main(args: Array[String]): Unit = {
val spark=SparkSession.builder().appName("sparkSql").master("local[4]") .config("spark.serializer","org.apache.spark.serializer.KryoSerializer").getOrCreate() val sc=spark.sparkContext val sqlContext=spark.sqlContext val workerRDD=sc.textFile("F://Workers.txt").mapPartitions(itor=>{
val array=new ArrayBuffer[Worker]() while(itor.hasNext){
val splited=itor.next().split(",") array.append(new Worker(splited(0),splited(2).toInt,splited(2))) } array.toIterator }) import spark.implicits._ //注册UDAF spark.udf.register("countUDF",new CountUDAF()) val workDS=workerRDD.toDS() workDS.createOrReplaceTempView("worker") val resultDF=spark.sql("select countUDF(age) from worker") val resultDS=resultDF.as("WO") resultDS.show() spark.stop() } } -----------------------------------------------------------------------------------------------

转载于:https://www.cnblogs.com/runnerjack/p/10662338.html

你可能感兴趣的文章
云计算最大难处
查看>>
关于数据分析思路的4点心得
查看>>
Memcached安装与配置
查看>>
美团数据仓库的演进
查看>>
SAP被评为“大数据”预测分析领军企业
查看>>
联想企业网盘张跃华:让文件创造业务价值
查看>>
记录一次蚂蚁金服前端电话面试
查看>>
直播源码开发视频直播平台,不得不了解的流程
查看>>
Ubuntu上的pycrypto给出了编译器错误
查看>>
聊聊flink的RestClientConfiguration
查看>>
在CentOS上搭建git仓库服务器以及mac端进行克隆和提交到远程git仓库
查看>>
測試文章
查看>>
Flex很难?一文就足够了
查看>>
【BATJ面试必会】JAVA面试到底需要掌握什么?【上】
查看>>
CollabNet_Subversion小结
查看>>
mysql定时备份自动上传
查看>>
17岁时少年决定把海洋洗干净,现在21岁的他做到了
查看>>
linux 启动oracle
查看>>
《写给大忙人看的java se 8》笔记
查看>>
倒计时:计算时间差
查看>>