Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

apache spark - Adding a group count column to a PySpark dataframe

I am coming from R and the tidyverse to PySpark due to its superior Spark handling, and I am struggling to map certain concepts from one context to the other.

In particular, suppose that I had a dataset like the following

x | y
--+--
a | 5
a | 8
a | 7
b | 1

and I wanted to add a column containing the number of rows for each x value, like so:

x | y | n
--+---+---
a | 5 | 3
a | 8 | 3
a | 7 | 3
b | 1 | 1

In dplyr, I would just say:

library(tidyverse)

df <- read_csv("...")
df %>%
    group_by(x) %>%
    mutate(n = n()) %>%
    ungroup()

and that would be that. I can do something almost as simple in PySpark if I'm looking to summarize by number of rows:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

(
    spark.read.csv("...", header=True)  # header=True so the column is named "x"
    .groupBy(col("x"))
    .count()
    .show()
)

And I thought I understood that withColumn was equivalent to dplyr's mutate. However, when I do the following, PySpark tells me that withColumn is not defined for groupBy data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

(
    spark.read.csv("...", header=True)
    .groupBy(col("x"))
    .withColumn("n", count("x"))  # fails: GroupedData has no withColumn
    .show()
)

In the short run, I can simply create a second dataframe containing the counts and join it to the original dataframe. However, it seems like this could become inefficient in the case of large tables. What is the canonical way to accomplish this?


1 Answer

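groupBy() returns a GroupedData object rather than a dataframe, so withColumn() is not available on it. You have to apply an aggregation such as count() or agg() before you can display the results. For example: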

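from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
]
df = spark.createDataFrame(data, ["x", "y"])
# count() collapses to one row per group; alias() renames "count" to "n"
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()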
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

Here I used alias() to rename the column. But this only returns one row per group. If you want all rows with the count appended, you can do this with a Window:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+

Or, if you're more comfortable with SQL, you can register the dataframe as a temporary view and run the same window query through Spark SQL:

# createOrReplaceTempView replaces registerTempTable, deprecated since Spark 2.0
df.createOrReplaceTempView('table')
spark.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
