You can do a conditional count using count_if:
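The snippets below assume a DataFrame like the one in the output further down. If you want to follow along, a throwaway version can be built like this (column names and types are inferred from the result table, so treat this as a sketch):

from datetime import datetime

# hypothetical sample data reconstructed from the output shown below
data = [
    (datetime(2021, 1, day), country, flag)
    for country, flags in [('BRA', 'NNNYYY'), ('ITA', 'NNNYYNY'), ('FRA', 'NNNYYY')]
    for day, flag in enumerate(flags, start=1)
]
df = spark.createDataFrame(data, ['DT_BORD_REF', 'COUNTRY_ALPHA', 'working_day_flag'])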
df.createOrReplaceTempView('df')

result = spark.sql("""
select *,
       -- running count of working days from the start of the month up to the current row
       count_if(working_day_flag = 'Y')
         over (partition by country_alpha, trunc(dt_bord_ref, 'month')
               order by dt_bord_ref) month_to_date,
       -- count of working days strictly after the current row to the end of the month
       count_if(working_day_flag = 'Y')
         over (partition by country_alpha, trunc(dt_bord_ref, 'month')
               order by dt_bord_ref
               rows between 1 following and unbounded following) month_to_go
from df
""")
result.show()
+-------------------+-------------+----------------+-------------+-----------+
| DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
+-------------------+-------------+----------------+-------------+-----------+
|2021-01-01 00:00:00| BRA| N| 0| 3|
|2021-01-02 00:00:00| BRA| N| 0| 3|
|2021-01-03 00:00:00| BRA| N| 0| 3|
|2021-01-04 00:00:00| BRA| Y| 1| 2|
|2021-01-05 00:00:00| BRA| Y| 2| 1|
|2021-01-06 00:00:00| BRA| Y| 3| 0|
|2021-01-01 00:00:00| ITA| N| 0| 3|
|2021-01-02 00:00:00| ITA| N| 0| 3|
|2021-01-03 00:00:00| ITA| N| 0| 3|
|2021-01-04 00:00:00| ITA| Y| 1| 2|
|2021-01-05 00:00:00| ITA| Y| 2| 1|
|2021-01-06 00:00:00| ITA| N| 2| 1|
|2021-01-07 00:00:00| ITA| Y| 3| 0|
|2021-01-01 00:00:00| FRA| N| 0| 3|
|2021-01-02 00:00:00| FRA| N| 0| 3|
|2021-01-03 00:00:00| FRA| N| 0| 3|
|2021-01-04 00:00:00| FRA| Y| 1| 2|
|2021-01-05 00:00:00| FRA| Y| 2| 1|
|2021-01-06 00:00:00| FRA| Y| 3| 0|
+-------------------+-------------+----------------+-------------+-----------+
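Note that count_if was added in Spark 3.0. If you are on an older version, summing a case expression should be a drop-in replacement; a sketch under that assumption (the coalesce matters because sum over an empty frame returns null, whereas count_if returns 0):

result = spark.sql("""
select *,
       coalesce(sum(case when working_day_flag = 'Y' then 1 else 0 end)
         over (partition by country_alpha, trunc(dt_bord_ref, 'month')
               order by dt_bord_ref), 0) month_to_date,
       -- the frame after the last row of the month is empty, hence the coalesce
       coalesce(sum(case when working_day_flag = 'Y' then 1 else 0 end)
         over (partition by country_alpha, trunc(dt_bord_ref, 'month')
               order by dt_bord_ref
               rows between 1 following and unbounded following), 0) month_to_go
from df
""")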
If you want the equivalent in the PySpark DataFrame API:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

result = df.withColumn(
    'month_to_date',
    # count() skips nulls, and when() without otherwise() yields null for
    # non-matching rows, so this behaves like a conditional count
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
        .orderBy('dt_bord_ref')
    )
).withColumn(
    'month_to_go',
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
        .orderBy('dt_bord_ref')
        .rowsBetween(1, Window.unboundedFollowing)  # rows strictly after the current one
    )
)
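Both columns share the same partitioning and ordering, so you can factor the window spec out and only vary the frame; the behavior is unchanged (names like month_window are just illustrative):

month_window = (Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
                .orderBy('dt_bord_ref'))
is_working = F.when(F.col('working_day_flag') == 'Y', 1)

result = (
    df
    .withColumn('month_to_date', F.count(is_working).over(month_window))
    .withColumn('month_to_go',
                F.count(is_working).over(
                    month_window.rowsBetween(1, Window.unboundedFollowing)))
)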