Let's start with a simple UDF
Scenario:
Suppose we have a text file like this:
1^^d
2^b^d
3^c^d
4^^d
When reading the data, if the second column is empty, we want to display 'null'; otherwise we output its value as-is. Once the UDF is defined, it can be used directly in Spark SQL.
The code:
package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class test3 {
    public static void main(String[] args) {
        // Set up the Spark runtime environment
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test-udf");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        // Register the UDF; the empty check matters because String.split
        // yields "" (not null) for an empty middle field
        sqlContext.udf().register("isNull",
                (String field, String defaultValue) ->
                        field == null || field.isEmpty() ? defaultValue : field,
                DataTypes.StringType);

        // Read the file
        JavaRDD<String> lines = sc.textFile("C:\\test-udf.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create(line.split("\\^")));

        // Define the schema: three string columns a, b, c
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("c", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");

        sqlContext.sql("SELECT a, isNull(b, 'null'), c FROM test").show();

        sc.stop();
    }
}
The output:
+---+----+---+
|  a| _c1|  c|
+---+----+---+
|  1|null|  d|
|  2|   b|  d|
|  3|   c|  d|
|  4|null|  d|
+---+----+---+
The key line is this one:
sqlContext.udf().register("isNull", (String field, String defaultValue) -> field == null || field.isEmpty() ? defaultValue : field, DataTypes.StringType);
Here I wrote it with Java 8 lambda syntax. On versions before Java 8, you would instead implement the UDF2 interface (the two-argument UDF interface in Spark's Java API) as an anonymous class, as sketched below.
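For reference, a minimal sketch of that pre-Java-8 form, assuming the same sqlContext as in the example above; only the registration statement changes, and UDF2 comes from org.apache.spark.sql.api.java:

// Anonymous-class equivalent of the lambda registration above
sqlContext.udf().register("isNull", new UDF2<String, String, String>() {
    @Override
    public String call(String field, String defaultValue) throws Exception {
        return field == null || field.isEmpty() ? defaultValue : field;
    }
}, DataTypes.StringType);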
Now a custom UDAF: computing an average
Let's start with the simplest UDAF: computing an average. Many similar operations, such as max, min, cumulative sum, and string concatenation, can be implemented with the same approach.
First, define the UDAF class:
package test;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class MyAvg extends UserDefinedAggregateFunction {

    // Schema of the input column(s) passed to the function
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }

    // Schema of the aggregation buffer: a row count and a running sum
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("field2", DataTypes.IntegerType, true));
        return DataTypes.createStructType(structFields);
    }

    // Type of the final result
    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    // Whether the same input is guaranteed to produce the same output
    @Override
    public boolean deterministic() {
        return false;
    }

    // Initialize the buffer: count = 0, sum = 0
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
        buffer.update(1, 0);
    }

    // Fold one input row into the buffer (within a single partition)
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getInt(0) + 1);
        buffer.update(1, buffer.getInt(1) + Integer.valueOf(input.getString(0)));
    }

    // Merge buffers coming from different partitions
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
    }

    // Compute the final result from the buffer (integer division here)
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(1) / buffer.getInt(0);
    }
}
To use it, register it first; after that it can be called directly in Spark SQL:
package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class test4 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        // Register the UDAF (MyAvg lives in the same package)
        sqlContext.udf().register("my_avg", new MyAvg());

        JavaRDD<String> lines = sc.textFile("C:\\test4.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");

        sqlContext.sql("SELECT my_avg(b) FROM test GROUP BY a").show();

        sc.stop();
    }
}
The input file contents:
a^3
a^6
b^2
b^4
b^6
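As a sanity check (reasoned from the data rather than captured from a run): with integer division, group a averages (3+6)/2 = 4 and group b averages (2+4+6)/3 = 4, so the query should show something like this (the auto-generated column name may differ):

+---+
|_c0|
+---+
|  4|
|  4|
+---+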
Now an all-powerful UDAF
Real business scenarios always throw up all kinds of odd requirements, for example:
Group by some field and take the maximum value within each group
Group by some field and accumulate a statistic over the group's data by a particular column
Group by some field and concatenate strings for rows matching certain conditions
Or another scenario: you need to group by some field, then de-duplicate the rows within each group by a particular column, and finally compute a value. That is:
1 Group by some field
2 De-duplicate within each group
3 Sum the metric values
Without a UDAF, writing this in plain Spark might look like this:
rdd.groupBy(r -> r.xxx)               // 1. group by the key field
   .map(t2 -> {
       // 2. de-duplicate within the group: later rows overwrite earlier ones
       HashMap<String, Integer> map = new HashMap<>();
       for (Object p : t2._2) {
           map.put(xx, yyy);
       }
       // 3. sum the de-duplicated metric values
       return map.values().stream().reduce(0, Integer::sum);
   });
The above is pseudocode and isn't guaranteed to run.
Written this way, it does get the job done, but the code looks a bit ugly, and it's not as clear and readable as Spark SQL...
So let's try a version using a Spark SQL UDAF!
First, create the UDAF class:
package test;

import org.apache.commons.lang.StringUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.*;

import java.util.*;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class ConditionJoinUDAF extends UserDefinedAggregateFunction {

    // Two inputs: an integer condition flag and the string value to join
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("field2", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }

    // The buffer holds the comma-joined string built so far
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }

    @Override
    public DataType dataType() {
        return DataTypes.StringType;
    }

    // Whether the same input is guaranteed to produce the same output
    @Override
    public boolean deterministic() {
        return false;
    }

    // Initialize the buffer to an empty string
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, "");
    }

    // Fold one input row into the buffer (within a single partition):
    // append the value only if the flag is positive and the value is not
    // already present (note: contains() is a substring check, so "111"
    // would be treated as present once "1111" has been appended)
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        Integer bs = input.getInt(0);
        String field = buffer.getString(0);
        String in = input.getString(1);
        if (bs > 0 && !"".equals(in) && !field.contains(in)) {
            field += "," + in;
        }
        buffer.update(0, field);
    }

    // Merge buffers coming from different partitions
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        String field1 = buffer1.getString(0);
        String field2 = buffer2.getString(0);
        if (!"".equals(field2)) {
            field1 += "," + field2;
        }
        buffer1.update(0, field1);
    }

    // Compute the final result: drop empty segments left by the leading commas
    @Override
    public Object evaluate(Row buffer) {
        return StringUtils.join(
                Arrays.stream(buffer.getString(0).split(","))
                        .filter(line -> !line.equals(""))
                        .toArray(),
                ",");
    }
}
Let's test it with an example:
a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^2222^0
a^3333^1
b^4444^0
b^5555^3
c^6666^0
Group by the first column, and for rows whose third column is greater than 0, concatenate the distinct values of the second column.
package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class test2 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        // Register the UDAF under the name used in the SQL below
        sqlContext.udf().register("con_join", new ConditionJoinUDAF());

        JavaRDD<String> lines = sc.textFile("C:\\test2.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("c", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");

        sqlContext.sql("SELECT con_join(c,b) FROM test GROUP BY a").show();

        sc.stop();
    }
}
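For reference, reasoning through the data above rather than pasting captured output: group a keeps 1111 (de-duplicated) and 3333 (flag 1) but drops 2222 (flag 0); group b keeps 5555 and drops 4444; group c drops 6666, leaving an empty string. So the result should look roughly like this (row order and the auto-generated column name may vary):

+---------+
|      _c0|
+---------+
|1111,3333|
|     5555|
|         |
+---------+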
This way the SQL stays concise and clearly expresses the intent.