Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


performance - DataFrame / Dataset groupBy behaviour/optimization

Suppose we have a DataFrame df consisting of the following columns:

Name, Surname, Size, Width, Length, Weight

Now we want to perform a couple of operations; for example, we want to create a couple of DataFrames containing aggregated data about Size and Width.

val df1 = df.groupBy("surname").agg( sum("size") )
val df2 = df.groupBy("surname").agg( sum("width") )

As you can see, other columns, such as Length, are not used anywhere. Is Spark smart enough to drop the redundant columns before the shuffling phase, or are they carried around? Will running:

val dfBasic = df.select("surname", "size", "width")

before grouping somehow affect the performance?


1 Answer


Yes, it is "smart enough". groupBy on a DataFrame is not the same operation as groupBy on a plain RDD. In the scenario you've described there is no need to move the raw rows at all: only the required columns are kept, and they are partially aggregated before the shuffle. Let's create a small example to illustrate that:

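// Run in spark-shell, where sc is defined and spark.implicits._ is already imported.
import org.apache.spark.sql.functions.sum

val df = sc.parallelize(Seq(
   ("a", "foo", 1), ("a", "foo", 3), ("b", "bar", 5), ("b", "bar", 1)
)).toDF("x", "y", "z")

// Print the physical plan for the grouped sum.
df.groupBy("x").agg(sum($"z")).explain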

// == Physical Plan ==
// *HashAggregate(keys=[x#148], functions=[sum(cast(z#150 as bigint))])
// +- Exchange hashpartitioning(x#148, 200)
//    +- *HashAggregate(keys=[x#148], functions=[partial_sum(cast(z#150 as bigint))])
//       +- *Project [_1#144 AS x#148, _3#146 AS z#150]
//          +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#144, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#145, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#146]
//             +- Scan ExternalRDDScan[obj#143]

As you can see (reading the plan from the bottom up), the first phase is a projection that keeps only the required columns. Next, the data is aggregated locally (partial_sum), then transferred (Exchange), and finally aggregated globally. The output looks slightly different on Spark <= 1.4, but the general structure should be exactly the same.
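The three phases can be mimicked outside Spark to make the data movement concrete. The following is a plain-Scala sketch, not Spark's implementation: the Row case class, the two hard-coded partitions, and the helper names project, partialSum and finalSum are all made up for illustration.

```scala
// Plain-Scala model of project -> partial aggregate -> shuffle -> final aggregate.
// Not Spark code: Row, the partitions and all helper names are illustrative.
object TwoPhaseSum {
  final case class Row(x: String, y: String, z: Int)

  // Two "partitions" holding the same four rows as the DataFrame example.
  val partitions: Seq[Seq[Row]] = Seq(
    Seq(Row("a", "foo", 1), Row("b", "bar", 5)),
    Seq(Row("a", "foo", 3), Row("b", "bar", 1))
  )

  // Phase 1: keep only the grouping key and the summed column (y is dropped).
  def project(rows: Seq[Row]): Seq[(String, Long)] =
    rows.map(r => (r.x, r.z.toLong))

  // Phase 2: aggregate inside each partition, leaving one record per key.
  def partialSum(pairs: Seq[(String, Long)]): Map[String, Long] =
    pairs.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

  // Phase 3: merge the partial results that arrive after the "shuffle".
  def finalSum(partials: Seq[Map[String, Long]]): Map[String, Long] =
    partials.flatten.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

  def run(): Map[String, Long] =
    finalSum(partitions.map(p => partialSum(project(p))))
}
```

TwoPhaseSum.run() yields Map("a" -> 4, "b" -> 6). In this model the y column never leaves its partition, and each partition sends at most one record per key.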

Finally, a DAG visualization confirms that the description above matches the actual job:

[Figure: group by and agg DAG]
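To check the same thing against the columns from the question, one option is to compare what the optimizer leaves at the bottom of the plan with and without the explicit select. The sketch below is only one way to do that, under a few assumptions: the sample rows are invented, the names PruningCheck, leafColumns and compare are hypothetical, and queryExecution is a developer-facing API whose output can differ between Spark versions.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.sum

object PruningCheck {
  // Column names still present at the leaves of the optimized logical plan.
  def leafColumns(df: DataFrame): Seq[String] =
    df.queryExecution.optimizedPlan.collectLeaves().flatMap(_.output.map(_.name))

  // Returns the leaf columns for (grouping directly, grouping after an explicit select).
  def compare(spark: SparkSession): (Seq[String], Seq[String]) = {
    import spark.implicits._

    // Invented rows; only the column names come from the question.
    val df = Seq(
      ("Ann", "Smith", 1, 2, 3, 4),
      ("Bob", "Smith", 5, 6, 7, 8)
    ).toDF("name", "surname", "size", "width", "length", "weight")

    val direct   = df.groupBy("surname").agg(sum("size"))
    val selected = df.select("surname", "size", "width")
      .groupBy("surname").agg(sum("size"))

    // Print both physical plans so they can be compared by eye.
    direct.explain()
    selected.explain()

    (leafColumns(direct), leafColumns(selected))
  }
}
```

In spark-shell, PruningCheck.compare(spark) prints both physical plans and returns the two column lists. If the lists match, the explicit select did not change what reaches the aggregation.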

Similarly, Dataset.groupByKey followed by reduceGroups contains both a map-side reduction (ObjectHashAggregate with partial_reduceaggregator) and a reduce-side reduction (ObjectHashAggregate with reduceaggregator):

case class Foo(x: String, y: String, z: Int)

val ds = df.as[Foo]
ds.groupByKey(_.x).reduceGroups((x, y) => x.copy(z = x.z + y.z)).explain

// == Physical Plan ==
// ObjectHashAggregate(keys=[value#126], functions=[reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
// +- Exchange hashpartitioning(value#126, 200)
//    +- ObjectHashAggregate(keys=[value#126], functions=[partial_reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
//       +- AppendColumns <function1>, newInstance(class $line40.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#126]
//          +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//             +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                +- Scan ExternalRDDScan[obj#3]

[Figure: groupByKey + reduceGroups]

However, other methods of KeyValueGroupedDataset might work similarly to RDD.groupByKey. For example, mapGroups (or flatMapGroups) doesn't use partial aggregation: the plan below has no aggregate before the Exchange, so every row is shuffled.

ds.groupByKey(_.x)
  .mapGroups((_, iter) => iter.reduce((x, y) => x.copy(z = x.z + y.z)))
  .explain

//== Physical Plan ==
//*SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).x, true, false) AS x#37, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).y, true, false) AS y#38, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).z AS z#39]
//+- MapGroups <function2>, value#32.toString, newInstance(class $line15.$read$$iw$$iw$Foo), [value#32], [x#8, y#9, z#10], obj#36: $line15.$read$$iw$$iw$Foo
//   +- *Sort [value#32 ASC NULLS FIRST], false, 0
//      +- Exchange hashpartitioning(value#32, 200)
//         +- AppendColumns <function1>, newInstance(class $line15.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#32]
//            +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//               +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                  +- Scan ExternalRDDScan[obj#3]

[Figure: groupByKey + mapGroups]
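The practical difference between the two plans is how many records cross the Exchange. A rough plain-Scala model, with no Spark involved, makes that countable. Foo mirrors the case class above; the two partitions and the names ShuffleVolume, shuffledWithPartial, shuffledWithoutPartial and reduceSide are assumptions made for this sketch.

```scala
// Plain-Scala model contrasting a reduceGroups-like and a mapGroups-like plan.
// Not Spark code: the partitions and all helper names are illustrative.
object ShuffleVolume {
  final case class Foo(x: String, y: String, z: Int)

  val partitions: Seq[Seq[Foo]] = Seq(
    Seq(Foo("a", "foo", 1), Foo("a", "foo", 3), Foo("b", "bar", 5)),
    Seq(Foo("a", "foo", 2), Foo("b", "bar", 1), Foo("b", "bar", 4))
  )

  // Same reduction function as in the Dataset examples.
  def merge(l: Foo, r: Foo): Foo = l.copy(z = l.z + r.z)

  // reduceGroups-like: reduce per key inside each partition before the shuffle.
  def shuffledWithPartial: Seq[(String, Foo)] =
    partitions.flatMap(p => p.groupBy(_.x).map { case (k, vs) => k -> vs.reduce(merge) })

  // mapGroups-like: every keyed row crosses the shuffle untouched.
  def shuffledWithoutPartial: Seq[(String, Foo)] =
    partitions.flatMap(p => p.map(f => f.x -> f))

  // Reduce side: the same final reduction for either input.
  def reduceSide(shuffled: Seq[(String, Foo)]): Map[String, Int] =
    shuffled.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).reduce(merge).z }
}
```

With this input, shuffledWithPartial holds 4 records and shuffledWithoutPartial holds 6, while reduceSide returns Map("a" -> 6, "b" -> 10) for both. The gap widens as the number of rows per key in each partition grows.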

