#!/usr/bin/env python3

# pylint: disable=C0413,C0411,C0116,C0114,E1101,W0621

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

# Create the SparkSession used by every example below (or reuse an existing
# one).
spark = SparkSession.builder.getOrCreate()

# tag::ch10-read-parquet[]
gsod = spark.read.format("parquet").load("./data/window/gsod.parquet")
# end::ch10-read-parquet[]

# Lowest temperature of each year with a plain group-by. The aggregate is
# aliased back to "temp" so it can be matched against gsod on
# ["year", "temp"] in the next listing.
# tag::code-window-lowest-group-by[]

coldest_temp = gsod.groupBy("year").agg(F.min(F.col("temp")).alias("temp"))
coldest_temp.sort("temp").show()

# +----+------+
# |year|  temp|
# +----+------+
# |2017|-114.7|
# |2018|-113.5|
# |2019|-114.7|
# +----+------+

# end::code-window-lowest-group-by[]

# Recover the station(s) and day(s) behind each yearly minimum: the left semi
# join keeps the gsod rows whose (year, temp) pair appears in coldest_temp.
# tag::code-window-left-semi-self-join[]

coldest_when = gsod.join(
    coldest_temp, on=["year", "temp"], how="left_semi"
).select("stn", "year", "mo", "da", "temp")

coldest_when.sort("year", "mo", "da").show()

# +------+----+---+---+------+
# |   stn|year| mo| da|  temp|
# +------+----+---+---+------+
# |896250|2017| 06| 20|-114.7|
# |896060|2018| 08| 27|-113.5|
# |895770|2019| 06| 15|-114.7|
# +------+----+---+---+------+

# end::code-window-left-semi-self-join[]

# A window spec with one partition per year. Printing it shows that it is
# only a WindowSpec object; it is applied to data later with .over().
# tag::code-window-create-window[]

from pyspark.sql.window import Window  # <1>

each_year = Window.partitionBy(F.col("year"))  # <2>

print(each_year)
# <pyspark.sql.window.WindowSpec object at 0x7f978fc8e6a0>

# end::code-window-create-window[]

# Window version of the coldest-day query: every row receives its year's
# minimum temperature (<1>), then only rows equal to that minimum are kept.
# fmt:off
# tag::code-window-apply-window[]

(
    gsod.withColumn("min_temp", F.min("temp").over(each_year))  # <1>
    .filter("temp = min_temp")
    .select("year", "mo", "da", "stn", "temp")
    .sort("year", "mo", "da")
    .show()
)
# +----+---+---+------+------+
# |year| mo| da|   stn|  temp|
# +----+---+---+------+------+
# |2017| 06| 20|896250|-114.7|
# |2018| 08| 27|896060|-113.5|
# |2019| 06| 15|895770|-114.7|
# +----+---+---+------+------+

# end::code-window-apply-window[]
# fmt:on

# Same rows and columns as the previous listing, but the window column is
# computed inside select() and dropped (<1>) once it has served as a filter.
# tag::code-window-apply-window2[]

(
    gsod.select(
        "year",
        "mo",
        "da",
        "stn",
        "temp",
        F.min("temp").over(each_year).alias("min_temp"),
    )
    .filter("temp = min_temp")
    .drop("min_temp")  # <1>
    .sort("year", "mo", "da")
    .show()
)

# end::code-window-apply-window2[]


# tag::code-window-way-speed[]


def the_self_join_way(gsod):
    """Return each year's coldest reading(s) using a group-by and a self-join.

    The yearly minimum of ``temp`` is computed with a group-by, then joined
    back onto ``gsod`` with a left semi join on ``year`` and ``temp``, so every
    row tied with its year's minimum is kept.

    Returns a DataFrame with columns stn, year, mo, da, temp.
    """
    yearly_minimum = gsod.groupBy("year").agg(F.min("temp").alias("temp"))
    matching_rows = gsod.join(
        yearly_minimum, on=["year", "temp"], how="left_semi"
    )
    return matching_rows.select("stn", "year", "mo", "da", "temp")


def the_window_way(gsod):
    """Return each year's coldest reading(s) using a window function.

    Every row is given its year's minimum ``temp`` through a window
    partitioned by ``year``. Rows equal to that minimum are kept.

    Returns a DataFrame with columns year, mo, da, stn, temp. This column
    order differs from ``the_self_join_way``.
    """
    by_year = Window.partitionBy("year")
    flagged = gsod.withColumn("min_temp", F.min("temp").over(by_year))
    return flagged.filter("temp = min_temp").select(
        "year", "mo", "da", "stn", "temp"
    )


# %timeit the_self_join_way(gsod).show()
# [...]
#

# %timeit the_window_way(gsod).show()
# [...]
#

# end::code-window-way-speed[]

# A ten-row sample (shown below) used by the ranking and analytic function
# examples that follow.
# tag::code-window-read-gsod-light[]

gsod_light = spark.read.format("parquet").load(
    "./data/window/gsod_light.parquet"
)

gsod_light.show()
# +------+----+---+---+----+----------+
# |   stn|year| mo| da|temp|count_temp|
# +------+----+---+---+----+----------+
# |994979|2017| 12| 11|21.3|        21|
# |998012|2017| 03| 02|31.4|        24|
# |719200|2017| 10| 09|60.5|        11|
# |917350|2018| 04| 21|82.6|         9|
# |076470|2018| 06| 07|65.0|        24|
# |996470|2018| 03| 12|55.6|        12|
# |041680|2019| 02| 19|16.1|        15|
# |949110|2019| 11| 23|54.9|        14|
# |998252|2019| 04| 18|44.7|        11|
# |998166|2019| 03| 20|34.8|        12|
# +------+----+---+---+----+----------+

# end::code-window-read-gsod-light[]

# One partition per month (<1>), with rows inside each partition ordered by
# ascending count_temp (<2>).
# tag::code-window-orderby[]

temp_per_month_asc = (
    Window.partitionBy("mo")  # <1>
    .orderBy("count_temp")  # <2>
)

# end::code-window-orderby[]

# rank() over the month / count_temp window. Rows tied on count_temp share a
# rank and the following rank skips ahead (month 03 below: 1, 1, 3).
# tag::code-window-rank[]

(
    gsod_light.withColumn("rank_tpm", F.rank().over(temp_per_month_asc))  # <1>
    .show()
)
# +------+----+---+---+----+----------+--------+
# |   stn|year| mo| da|temp|count_temp|rank_tpm|
# +------+----+---+---+----+----------+--------+
# |949110|2019| 11| 23|54.9|        14|       1| <2>
# |996470|2018| 03| 12|55.6|        12|       1| <3>
# |998166|2019| 03| 20|34.8|        12|       1| <3>
# |998012|2017| 03| 02|31.4|        24|       3| <4>
# |041680|2019| 02| 19|16.1|        15|       1|
# |076470|2018| 06| 07|65.0|        24|       1|
# |719200|2017| 10| 09|60.5|        11|       1|
# |994979|2017| 12| 11|21.3|        21|       1|
# |917350|2018| 04| 21|82.6|         9|       1|
# |998252|2019| 04| 18|44.7|        11|       2|
# +------+----+---+---+----+----------+--------+

# end::code-window-rank[]

# tag::code-window-complete-rank[]

# +------+----+---+---+----+----------+--------+
# |   stn|year| mo| da|temp|count_temp|rank_tpm|
# +------+----+---+---+----+----------+--------+
# |949110|2019| 11| 23|54.9|        14|       1|
# |996470|2018| 03| 12|55.6|        12|       1| <-- One of those records
# |998166|2019| 03| 20|34.8|        12|       1| <--  would have a rank of 2
# |998012|2017| 03| 02|31.4|        24|       3|
# |041680|2019| 02| 19|16.1|        15|       1|
# |076470|2018| 06| 07|65.0|        24|       1|
# |719200|2017| 10| 09|60.5|        11|       1|
# |994979|2017| 12| 11|21.3|        21|       1|
# |917350|2018| 04| 21|82.6|         9|       1|
# |998252|2019| 04| 18|44.7|        11|       2|
# +------+----+---+---+----+----------+--------+

# end::code-window-complete-rank[]

# dense_rank() (<1>): tied rows still share a rank, but the next rank follows
# without a gap (month 03 below: 1, 1, 2).
# tag::code-window-dense-rank[]

(
    gsod_light.withColumn(
        "rank_tpm", F.dense_rank().over(temp_per_month_asc)  # <1>
    )
    .show()
)

# +------+----+---+---+----+----------+--------+
# |   stn|year| mo| da|temp|count_temp|rank_tpm|
# +------+----+---+---+----+----------+--------+
# |949110|2019| 11| 23|54.9|        14|       1|
# |996470|2018| 03| 12|55.6|        12|       1| <2>
# |998166|2019| 03| 20|34.8|        12|       1| <2>
# |998012|2017| 03| 02|31.4|        24|       2| <3>
# |041680|2019| 02| 19|16.1|        15|       1|
# |076470|2018| 06| 07|65.0|        24|       1|
# |719200|2017| 10| 09|60.5|        11|       1|
# |994979|2017| 12| 11|21.3|        21|       1|
# |917350|2018| 04| 21|82.6|         9|       1|
# |998252|2019| 04| 18|44.7|        11|       2|
# +------+----+---+---+----+----------+--------+


# end::code-window-dense-rank[]

# Reuse the each_year partitioning, now ordered by ascending temp (<1>).
# percent_rank() places each row between 0.0 (lowest temp of its year) and
# 1.0 (highest), evenly spaced when there are no ties (2019 below: 0, 1/3,
# 2/3, 1).
# tag::code-window-percent-rank[]

temp_each_year = each_year.orderBy(F.col("temp"))  # <1>

(
    gsod_light.withColumn("rank_tpm", F.percent_rank().over(temp_each_year))
    .show()
)

# +------+----+---+---+----+----------+------------------+
# |   stn|year| mo| da|temp|count_temp|          rank_tpm|
# +------+----+---+---+----+----------+------------------+
# |041680|2019| 02| 19|16.1|        15|               0.0|
# |998166|2019| 03| 20|34.8|        12|0.3333333333333333|
# |998252|2019| 04| 18|44.7|        11|0.6666666666666666| <2>
# |949110|2019| 11| 23|54.9|        14|               1.0|
# |994979|2017| 12| 11|21.3|        21|               0.0|
# |998012|2017| 03| 02|31.4|        24|               0.5|
# |719200|2017| 10| 09|60.5|        11|               1.0|
# |996470|2018| 03| 12|55.6|        12|               0.0|
# |076470|2018| 06| 07|65.0|        24|               0.5|
# |917350|2018| 04| 21|82.6|         9|               1.0|
# +------+----+---+---+----+----------+------------------+

# end::code-window-percent-rank[]


# ntile(2) splits each year's temp-ordered rows into two buckets. With an odd
# row count the first bucket gets the extra row (2017 and 2018 below).
# tag::code-window-ntile[]

(
    gsod_light.withColumn("rank_tpm", F.ntile(2).over(temp_each_year))
    .show()
)

# +------+----+---+---+----+----------+--------+
# |   stn|year| mo| da|temp|count_temp|rank_tpm|
# +------+----+---+---+----+----------+--------+
# |041680|2019| 02| 19|16.1|        15|       1|
# |998166|2019| 03| 20|34.8|        12|       1|
# |998252|2019| 04| 18|44.7|        11|       2|
# |949110|2019| 11| 23|54.9|        14|       2|
# |994979|2017| 12| 11|21.3|        21|       1|
# |998012|2017| 03| 02|31.4|        24|       1|
# |719200|2017| 10| 09|60.5|        11|       2|
# |996470|2018| 03| 12|55.6|        12|       1|
# |076470|2018| 06| 07|65.0|        24|       1|
# |917350|2018| 04| 21|82.6|         9|       2|
# +------+----+---+---+----+----------+--------+

# end::code-window-ntile[]

# row_number() numbers each year's temp-ordered rows 1, 2, 3, ... with no
# shared values.
# tag::code-window-row-number[]

(
    gsod_light.withColumn("rank_tpm", F.row_number().over(temp_each_year))
    .show()
)

# +------+----+---+---+----+----------+--------+
# |   stn|year| mo| da|temp|count_temp|rank_tpm|
# +------+----+---+---+----+----------+--------+
# |041680|2019| 02| 19|16.1|        15|       1| <1>
# |998166|2019| 03| 20|34.8|        12|       2| <1>
# |998252|2019| 04| 18|44.7|        11|       3| <1>
# |949110|2019| 11| 23|54.9|        14|       4| <1>
# |994979|2017| 12| 11|21.3|        21|       1|
# |998012|2017| 03| 02|31.4|        24|       2|
# |719200|2017| 10| 09|60.5|        11|       3|
# |996470|2018| 03| 12|55.6|        12|       1|
# |076470|2018| 06| 07|65.0|        24|       2|
# |917350|2018| 04| 21|82.6|         9|       3|
# +------+----+---+---+----+----------+--------+

# end::code-window-row-number[]


# Same month partitioning, but ordered by descending count_temp (<1>).
# NOTE: rows tied on count_temp (996470 and 998166 in month 03) still get
# distinct row numbers. This ordering alone does not decide which of them is
# 2 and which is 3.
# tag::code-window-desc[]

temp_per_month_desc = Window.partitionBy("mo").orderBy(
    F.desc("count_temp")  # <1>
)

(
    gsod_light.withColumn(
        "row_number", F.row_number().over(temp_per_month_desc)
    )
    .show()
)

# +------+----+---+---+----+----------+----------+
# |   stn|year| mo| da|temp|count_temp|row_number|
# +------+----+---+---+----+----------+----------+
# |949110|2019| 11| 23|54.9|        14|         1|
# |998012|2017| 03| 02|31.4|        24|         1|
# |996470|2018| 03| 12|55.6|        12|         2|
# |998166|2019| 03| 20|34.8|        12|         3|
# |041680|2019| 02| 19|16.1|        15|         1|
# |076470|2018| 06| 07|65.0|        24|         1|
# |719200|2017| 10| 09|60.5|        11|         1|
# |994979|2017| 12| 11|21.3|        21|         1|
# |998252|2019| 04| 18|44.7|        11|         1|
# |917350|2018| 04| 21|82.6|         9|         2|
# +------+----+---+---+----+----------+----------+

# end::code-window-desc[]


# lag("temp", n) reads temp from n rows earlier within the same year, ordered
# by temp. It is null when no such row exists (the first row, or first two
# rows, of each year).
# tag::code-window-lag[]

(
    gsod_light.withColumn("previous_temp", F.lag("temp", 1).over(temp_each_year))
    .withColumn("previous_temp_2", F.lag("temp", 2).over(temp_each_year))
    .show()
)

# +------+----+---+---+----+----------+-------------+---------------+
# |   stn|year| mo| da|temp|count_temp|previous_temp|previous_temp_2|
# +------+----+---+---+----+----------+-------------+---------------+
# |041680|2019| 02| 19|16.1|        15|         null|           null|
# |998166|2019| 03| 20|34.8|        12|         16.1|           null|  <1>
# |998252|2019| 04| 18|44.7|        11|         34.8|           16.1|  <1>
# |949110|2019| 11| 23|54.9|        14|         44.7|           34.8|
# |994979|2017| 12| 11|21.3|        21|         null|           null|
# |998012|2017| 03| 02|31.4|        24|         21.3|           null|
# |719200|2017| 10| 09|60.5|        11|         31.4|           21.3|
# |996470|2018| 03| 12|55.6|        12|         null|           null|
# |076470|2018| 06| 07|65.0|        24|         55.6|           null|
# |917350|2018| 04| 21|82.6|         9|         65.0|           55.6|
# +------+----+---+---+----+----------+-------------+---------------+

# end::code-window-lag[]

# percent_rank() next to cume_dist(). cume_dist is the fraction of the year's
# rows whose temp is less than or equal to the current row's (2019 below:
# 0.25, 0.5, 0.75, 1.0).
# tag::code-window-cume-dist[]

(
    gsod_light.withColumn("percent_rank", F.percent_rank().over(temp_each_year))
    .withColumn("cume_dist", F.cume_dist().over(temp_each_year))
    .show()
)

# +------+----+---+---+----+----------+----------------+----------------+
# |   stn|year| mo| da|temp|count_temp|    percent_rank|       cume_dist|
# +------+----+---+---+----+----------+----------------+----------------+
# |041680|2019| 02| 19|16.1|        15|             0.0|            0.25|
# |998166|2019| 03| 20|34.8|        12|0.33333333333333|             0.5|
# |998252|2019| 04| 18|44.7|        11|0.66666666666666|            0.75|
# |949110|2019| 11| 23|54.9|        14|             1.0|             1.0|
# |994979|2017| 12| 11|21.3|        21|             0.0|0.33333333333333|
# |998012|2017| 03| 02|31.4|        24|             0.5|0.66666666666666|
# |719200|2017| 10| 09|60.5|        11|             1.0|             1.0|
# |996470|2018| 03| 12|55.6|        12|             0.0|0.33333333333333|
# |076470|2018| 06| 07|65.0|        24|             0.5|0.66666666666666|
# |917350|2018| 04| 21|82.6|         9|             1.0|             1.0|
# +------+----+---+---+----+----------+----------------+----------------+


# end::code-window-cume-dist[]


# The same avg() over two windows. Without an orderBy (avg_NO, <1>) every row
# gets its whole year's average. With one (avg_O, <2>) each row averages only
# the rows up to its own temp, i.e. a running average.
# tag::code-window-ordered-and-not[]

not_ordered = Window.partitionBy("year")
ordered = not_ordered.orderBy("temp")

(
    gsod_light.withColumn("avg_NO", F.avg("temp").over(not_ordered))
    .withColumn("avg_O", F.avg("temp").over(ordered))
    .show()
)

# +------+----+---+---+----+----------+----------------+------------------+
# |   stn|year| mo| da|temp|count_temp|          avg_NO|             avg_O|
# +------+----+---+---+----+----------+----------------+------------------+
# |041680|2019| 02| 19|16.1|        15|          37.625|              16.1|
# |998166|2019| 03| 20|34.8|        12|          37.625|             25.45|
# |998252|2019| 04| 18|44.7|        11|          37.625|31.866666666666664|
# |949110|2019| 11| 23|54.9|        14|          37.625|            37.625|
# |994979|2017| 12| 11|21.3|        21|37.7333333333334|              21.3|
# |998012|2017| 03| 02|31.4|        24|37.7333333333334|             26.35|
# |719200|2017| 10| 09|60.5|        11|37.7333333333334|37.733333333333334|
# |996470|2018| 03| 12|55.6|        12| 67.733333333333|              55.6|
# |076470|2018| 06| 07|65.0|        24| 67.733333333333|              60.3|
# |917350|2018| 04| 21|82.6|         9| 67.733333333333| 67.73333333333333|
# +------+----+---+---+----+----------+----------------+------------------+
#                                             <1>
#                                                                <2>

# end::code-window-ordered-and-not[]

# The two windows from the previous listing with their frames written out:
# <1> a rows frame spanning the whole year partition (the avg_NO behaviour);
# <2> a range frame, based on the temp ordering, from the start of the
# partition through the current row (the avg_O behaviour). The rangeBetween
# call replaces the rows frame inherited from not_ordered. Per the PySpark
# Window documentation, these are the frames Spark uses by default without /
# with an orderBy.
# tag::code-window-ordered-and-not2[]
not_ordered = (
    Window.partitionBy("year")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)  # <1>
)
ordered = (
    not_ordered.orderBy("temp")
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)  # <2>
)
# end::code-window-ordered-and-not2[]

# Force every row into year 2019, so that all rows share one partition in the
# next listing's partitionBy("year"). Then build a date column from
# year/mo/da (<1>) and turn it into seconds since the epoch with
# unix_timestamp (<2>).
# NOTE(review): unix_timestamp interprets dt in the Spark session time zone.
# The dt_num values shown are local midnight at UTC-5 (UTC-4 after the March
# DST change), so a different session time zone yields different numbers.
# tag::code-window-date[]

gsod_light_p = (
    gsod_light.withColumn("year", F.lit(2019))
    .withColumn("dt", F.to_date(F.concat_ws("-", "year", "mo", "da")))
    .withColumn("dt_num", F.unix_timestamp("dt"))
)
gsod_light_p.show()

#                                          <1>
# +------+----+---+---+----+----------+----------+----------+
# |   stn|year| mo| da|temp|count_temp|        dt|    dt_num|
# +------+----+---+---+----+----------+----------+----------+
# |041680|2019| 02| 19|16.1|        15|2019-02-19|1550552400|
# |998012|2019| 03| 02|31.4|        24|2019-03-02|1551502800|
# |996470|2019| 03| 12|55.6|        12|2019-03-12|1552363200|
# |998166|2019| 03| 20|34.8|        12|2019-03-20|1553054400|
# |998252|2019| 04| 18|44.7|        11|2019-04-18|1555560000|
# |917350|2019| 04| 21|82.6|         9|2019-04-21|1555819200|
# |076470|2019| 06| 07|65.0|        24|2019-06-07|1559880000|
# |719200|2019| 10| 09|60.5|        11|2019-10-09|1570593600|
# |949110|2019| 11| 23|54.9|        14|2019-11-23|1574485200|
# |994979|2019| 12| 11|21.3|        21|2019-12-11|1576040400|
# +------+----+---+---+----+----------+----------+----------+
#                                                    <2>

# end::code-window-date[]

# Average count_temp over a sliding range frame on dt_num (<1>). For each
# row, the frame holds the rows of the same year whose dt_num lies within
# ONE_MONTH_ISH seconds before or after its own. The bounds are applied to the
# orderBy value, not to a number of rows.
# tag::code-window-range-one-month[]
ONE_MONTH_ISH = 30 * 24 * 60 * 60  # or 2_592_000 seconds
one_month_ish_before_and_after = (
    Window.partitionBy("year")
    .orderBy("dt_num")
    .rangeBetween(-ONE_MONTH_ISH, ONE_MONTH_ISH)  # <1>
)

(
    gsod_light_p.withColumn(
        "avg_count",
        F.avg("count_temp").over(one_month_ish_before_and_after),
    ).show()
)

# +------+----+---+---+----+----------+----------+----------+-------------+
# |   stn|year| mo| da|temp|count_temp|        dt|    dt_num|    avg_count|
# +------+----+---+---+----+----------+----------+----------+-------------+
# |041680|2019| 02| 19|16.1|        15|2019-02-19|1550552400|        15.75|
# |998012|2019| 03| 02|31.4|        24|2019-03-02|1551502800|        15.75|
# |996470|2019| 03| 12|55.6|        12|2019-03-12|1552363200|        15.75|
# |998166|2019| 03| 20|34.8|        12|2019-03-20|1553054400|         14.8|
# |998252|2019| 04| 18|44.7|        11|2019-04-18|1555560000|10.6666666666|
# |917350|2019| 04| 21|82.6|         9|2019-04-21|1555819200|         10.0|
# |076470|2019| 06| 07|65.0|        24|2019-06-07|1559880000|         24.0|
# |719200|2019| 10| 09|60.5|        11|2019-10-09|1570593600|         11.0|
# |949110|2019| 11| 23|54.9|        14|2019-11-23|1574485200|         17.5|
# |994979|2019| 12| 11|21.3|        21|2019-12-11|1576040400|         17.5|
# +------+----+---+---+----+----------+----------+----------+-------------+


# end::code-window-range-one-month[]

# tag::code-window-udf[]

import pandas as pd


# On Spark 2.4, use the explicit-type decorator below instead of the
# type-hint based one. PandasUDFType is reachable through F there:
# @F.pandas_udf("double", F.PandasUDFType.GROUPED_AGG)
@F.pandas_udf("double")
def median(vals: pd.Series) -> float:
    """Return the median of ``vals`` (a Series -> scalar pandas UDF).

    Delegates to ``pandas.Series.median``. Below it is applied with
    ``.over()`` to window frames of the ``temp`` column.
    """
    frame_median = vals.median()
    return frame_median


# <1> median over the whole year partition. <2> median over a frame ordered by
# date (mo, da), so each row only sees its year's rows up to its own date.
# Compare median_temp_g with median_temp in the output below.
(
    gsod_light.withColumn(
        "median_temp", median("temp").over(Window.partitionBy("year"))  # <1>
    )
    .withColumn(
        "median_temp_g",
        median("temp").over(
            Window.partitionBy("year").orderBy("mo", "da")  # <2>
        ),  # <2>
    )
    .show()
)

#                                          <3>
# +------+----+---+---+----+----------+-----------+-------------+
# |   stn|year| mo| da|temp|count_temp|median_temp|median_temp_g|
# +------+----+---+---+----+----------+-----------+-------------+
# |041680|2019| 02| 19|16.1|        15|      39.75|         16.1|
# |998166|2019| 03| 20|34.8|        12|      39.75|        25.45|
# |998252|2019| 04| 18|44.7|        11|      39.75|         34.8|
# |949110|2019| 11| 23|54.9|        14|      39.75|        39.75|
# |998012|2017| 03| 02|31.4|        24|       31.4|         31.4|
# |719200|2017| 10| 09|60.5|        11|       31.4|        45.95|
# |994979|2017| 12| 11|21.3|        21|       31.4|         31.4|
# |996470|2018| 03| 12|55.6|        12|       65.0|         55.6|
# |917350|2018| 04| 21|82.6|         9|       65.0|         69.1|
# |076470|2018| 06| 07|65.0|        24|       65.0|         65.0|
# +------+----+---+---+----+----------+-----------+-------------+
#                                                       <4>


# end::code-window-udf[]
