#!/usr/bin/env python3

# pylint: disable=C0413,C0411,C0116,C0114,E1101,W0621

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

# Reuse the active Spark session if one exists; otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

# Station readings used by the exercises below (the code relies on the
# `stn`, `year`, `mo`, `da` and `temp` columns).
gsod = spark.read.format("parquet").load("./data/window/gsod.parquet")

# tag::sol10_1[]

# One window per calendar day, spanning every station.
sol10_1 = Window.partitionBy("year", "mo", "da")

# Tag each reading with the highest temperature recorded that day, keep only
# the readings that reach it, and report them without the redundant `temp`.
res10_1 = (
    gsod.withColumn("max_this_day", F.max("temp").over(sol10_1))
    .filter(F.col("max_this_day") == F.col("temp"))
    .select("stn", "year", "mo", "da", "max_this_day")
)

res10_1.show(n=5)
# +------+----+---+---+------------+
# |   stn|year| mo| da|max_this_day|
# +------+----+---+---+------------+
# |406370|2017| 08| 11|       108.3|
# |672614|2017| 12| 10|        93.8|
# |944500|2018| 01| 04|        99.2|
# |954920|2018| 01| 12|        98.9|
# |647530|2018| 10| 01|       100.4|
# +------+----+---+---+------------+
# only showing top 5 rows

# end::sol10_1[]

# tag::sol10_2[]

# 1001 rows: `index` is the row number integer-divided by 4 (so most values
# appear four times) and `value` is always 2.
exo10_2 = spark.createDataFrame(
    [(i // 4, 2) for i in range(1001)],
    schema=["index", "value"],
)

exo10_2.show()
# +-----+-----+
# |index|value|
# +-----+-----+
# |    0|    2|
# |    0|    2|
# |    0|    2|
# |    0|    2|
# |    1|    2|
# |    1|    2|
# |    1|    2|
# |    1|    2|
# |    2|    2|
# |    2|    2|
# |    2|    2|
# |    2|    2|
# |    3|    2|
# |    3|    2|
# |    3|    2|
# |    3|    2|
# |    4|    2|
# |    4|    2|
# |    4|    2|
# |    4|    2|
# +-----+-----+
# only showing top 20 rows

# Every row of a partition carries the same `value`, so the ordering inside
# each `index` group consists only of ties.
sol10_2 = Window.partitionBy("index").orderBy("value")

# Split each partition into three tiles and show which tile each row gets.
exo10_2.select("*", F.ntile(3).over(sol10_2).alias("10_2")).show(10)
# +-----+-----+----+
# |index|value|10_2|
# +-----+-----+----+
# |   26|    2|   1|
# |   26|    2|   1|
# |   26|    2|   2|
# |   26|    2|   3|
# |   29|    2|   1|
# |   29|    2|   1|
# |   29|    2|   2|
# |   29|    2|   3|
# |   65|    2|   1|
# |   65|    2|   1|
# +-----+-----+----+
# only showing top 10 rows


# end::sol10_2[]

# tag::sol10_3[]

# One million and one rows that all hold the same ordering value.
exo10_3 = spark.createDataFrame([(10,) for _ in range(1_000_001)], ["ord"])

# Same partition and ordering, two frame types:
# - `row` counts physical rows from two before to two after the current one;
# - `range` counts every row whose `ord` lies within +/-2 of the current
#   value, which here is the whole data frame.
(
    exo10_3.withColumn(
        "row",
        F.count("ord").over(
            Window.partitionBy().orderBy("ord").rowsBetween(-2, 2)
        ),
    )
    .withColumn(
        "range",
        F.count("ord").over(
            Window.partitionBy().orderBy("ord").rangeBetween(-2, 2)
        ),
    )
    .show(10)
)
# +---+---+-------+
# |ord|row|  range|
# +---+---+-------+
# | 10|  3|1000001|
# | 10|  4|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# | 10|  5|1000001|
# +---+---+-------+
# only showing top 10 rows

# end::sol10_3[]

# fmt:off
# tag::exo10_4[]
# One window per year; also reused by the solution that follows.
each_year = Window.partitionBy("year")

# Keep, for every year, the readings that equal that year's minimum.
(gsod
 .withColumn("min_temp", F.min("temp").over(each_year))
 .filter(F.col("temp") == F.col("min_temp"))
 .select("year", "mo", "da", "stn", "temp")
 .sort("year", "mo", "da")
 .show())
# end::exo10_4[]
# fmt:on

# tag::sol10_4[]

# Keep each year's hottest readings, then average `temp` over what is left.
# The average is computed after the filter, so it only sees rows that already
# equal the yearly maximum.
(
    gsod.withColumn("max_temp", F.max("temp").over(each_year))
    .filter(F.col("temp") == F.col("max_temp"))
    .select(
        "year",
        "mo",
        "da",
        "stn",
        "temp",
        F.avg("temp").over(each_year).alias("avg_temp"),
    )
    .sort("year", "stn")
    .show()
)
# +----+---+---+------+-----+--------+
# |year| mo| da|   stn| temp|avg_temp|
# +----+---+---+------+-----+--------+
# |2017| 07| 06|403770|110.0|   110.0|
# |2017| 07| 24|999999|110.0|   110.0|
# |2018| 06| 06|405860|110.0|   110.0|
# |2018| 07| 12|407036|110.0|   110.0|
# |2018| 07| 26|723805|110.0|   110.0|
# |2018| 07| 16|999999|110.0|   110.0|
# |2019| 07| 07|405870|110.0|   110.0|
# |2019| 07| 15|606030|110.0|   110.0|
# |2019| 08| 02|606450|110.0|   110.0|
# |2019| 07| 14|999999|110.0|   110.0|
# +----+---+---+------+-----+--------+

# end::sol10_4[]

# tag::exo10_5[]

gsod_light = spark.read.format("parquet").load(
    "./data/window/gsod_light.parquet"
)

# Rows of the same month, ordered by ascending `count_temp`.
temp_per_month_asc = Window.partitionBy("mo").orderBy("count_temp")

gsod_light.select(
    "*", F.rank().over(temp_per_month_asc).alias("rank_tpm")  # <1>
).show()

# +------+----+---+---+----+----------+--------+
# |   stn|year| mo| da|temp|count_temp|rank_tpm|
# +------+----+---+---+----+----------+--------+
# |949110|2019| 11| 23|54.9|        14|       1|
# |996470|2018| 03| 12|55.6|        12|       1| <1>
# |998166|2019| 03| 20|34.8|        12|       1| <1>
# |998012|2017| 03| 02|31.4|        24|       3|
# |041680|2019| 02| 19|16.1|        15|       1|
# |076470|2018| 06| 07|65.0|        24|       1|
# |719200|2017| 10| 09|60.5|        11|       1|
# |994979|2017| 12| 11|21.3|        21|       1|
# |917350|2018| 04| 21|82.6|         9|       1|
# |998252|2019| 04| 18|44.7|        11|       2|
# +------+----+---+---+----+----------+--------+

# end::exo10_5[]

# tag::sol10_5[]

# Second window: ties on `count_temp` are broken with `row_tpm`. That column
# does not exist yet; it is created by the first `withColumn` below, before
# this window is applied.
temp_per_month_rnk = Window.partitionBy("mo").orderBy(
    "count_temp", "row_tpm"
)

# `temp_per_month_asc` is the window already defined for the exercise
# statement. `row_number()` gives every record of a month a distinct number,
# so ranking on (`count_temp`, `row_tpm`) can no longer produce ties.
gsod_light.withColumn(
    "row_tpm", F.row_number().over(temp_per_month_asc)
).withColumn("rank_tpm", F.rank().over(temp_per_month_rnk)).show()
# +------+----+---+---+----+----------+-------+--------+
# |   stn|year| mo| da|temp|count_temp|row_tpm|rank_tpm|
# +------+----+---+---+----+----------+-------+--------+
# |949110|2019| 11| 23|54.9|        14|      1|       1|
# |996470|2018| 03| 12|55.6|        12|      1|       1| <1>
# |998166|2019| 03| 20|34.8|        12|      2|       2| <1>
# |998012|2017| 03| 02|31.4|        24|      3|       3|
# |041680|2019| 02| 19|16.1|        15|      1|       1|
# |076470|2018| 06| 07|65.0|        24|      1|       1|
# |719200|2017| 10| 09|60.5|        11|      1|       1|
# |994979|2017| 12| 11|21.3|        21|      1|       1|
# |917350|2018| 04| 21|82.6|         9|      1|       1|
# |998252|2019| 04| 18|44.7|        11|      2|       2|
# +------+----+---+---+----+----------+-------+--------+

# end::sol10_5[]

# tag::sol10_6[]

# Frame: readings from the same station whose `dtu` (seconds) lies within
# seven days before or after the current reading.
seven_days = (
    Window.partitionBy("stn")
    .orderBy("dtu")
    .rangeBetween(-7 * 60 * 60 * 24, 7 * 60 * 60 * 24)
)

# NOTE(review): `dtu` comes from `unix_timestamp` on a date, and the sample
# output shows the offset within the day changing over the year, which looks
# like daylight saving time. A reading exactly seven days away across a clock
# change may then fall just outside the frame -- confirm whether that matters.
#
# Build the data frame first and display it afterwards: `show()` returns
# None, so chaining it into the assignment would leave `sol10_6` empty.
sol10_6 = (
    gsod.select(
        "stn",
        (F.to_date(F.concat_ws("-", "year", "mo", "da"))).alias("dt"),
        "temp",
    )
    .withColumn("dtu", F.unix_timestamp("dt"))
    .withColumn("max_temp", F.max("temp").over(seven_days))
    .where("temp = max_temp")
)
sol10_6.show(10)
# +------+----------+----+----------+--------+
# |   stn|        dt|temp|       dtu|max_temp|
# +------+----------+----+----------+--------+
# |010875|2017-01-08|46.2|1483851600|    46.2|
# |010875|2017-01-19|48.0|1484802000|    48.0|
# |010875|2017-02-03|45.3|1486098000|    45.3|
# |010875|2017-02-20|45.7|1487566800|    45.7|
# |010875|2017-03-14|45.7|1489464000|    45.7|
# |010875|2017-04-01|46.8|1491019200|    46.8|
# |010875|2017-04-20|46.1|1492660800|    46.1|
# |010875|2017-05-02|50.5|1493697600|    50.5|
# |010875|2017-05-27|51.4|1495857600|    51.4|
# |010875|2017-06-06|53.6|1496721600|    53.6|
# +------+----------+----+----------+--------+
# only showing top 10 rows

# end::sol10_6[]

# tag::exo10_7[]

ONE_MONTH_ISH = 30 * 24 * 60 * 60  # or 2_592_000 seconds
one_month_ish_before_and_after = (
    Window.partitionBy("year")
    .orderBy("dt_num")
    .rangeBetween(-ONE_MONTH_ISH, ONE_MONTH_ISH)  # <1>
)

# Force every record into 2019, then derive a date (`dt`) and its value in
# seconds (`dt_num`) so the range frame above has a number to order on.
gsod_light_p = (
    gsod_light.withColumn("year", F.lit(2019))
    .withColumn("dt", F.to_date(F.concat_ws("-", "year", "mo", "da")))
    .withColumn("dt_num", F.unix_timestamp("dt"))
)

gsod_light_p.select(
    "*",
    F.avg("count_temp")
    .over(one_month_ish_before_and_after)
    .alias("avg_count"),
).show()

# +------+----+---+---+----+----------+----------+----------+-------------+
# |   stn|year| mo| da|temp|count_temp|        dt|    dt_num|    avg_count|
# +------+----+---+---+----+----------+----------+----------+-------------+
# |041680|2019| 02| 19|16.1|        15|2019-02-19|1550552400|        15.75|
# |998012|2019| 03| 02|31.4|        24|2019-03-02|1551502800|        15.75|
# |996470|2019| 03| 12|55.6|        12|2019-03-12|1552363200|        15.75|
# |998166|2019| 03| 20|34.8|        12|2019-03-20|1553054400|         14.8|
# |998252|2019| 04| 18|44.7|        11|2019-04-18|1555560000|10.6666666666|
# |917350|2019| 04| 21|82.6|         9|2019-04-21|1555819200|         10.0|
# |076470|2019| 06| 07|65.0|        24|2019-06-07|1559880000|         24.0|
# |719200|2019| 10| 09|60.5|        11|2019-10-09|1570593600|         11.0|
# |949110|2019| 11| 23|54.9|        14|2019-11-23|1574485200|         17.5|
# |994979|2019| 12| 11|21.3|        21|2019-12-11|1576040400|         17.5|
# +------+----+---+---+----+----------+----------+----------+-------------+

# end::exo10_7[]

# tag::sol10_7[]

# Frame: records whose month index is at most one away from the current one.
one_month_before_and_after = (
    Window.partitionBy("year").orderBy("num_mo").rangeBetween(-1, 1)
)

# `num_mo` numbers months consecutively (year * 12 + month), so a range of
# +/-1 on it means "previous, current and next month".
(
    gsod_light_p.drop("dt", "dt_num")
    .select(
        "*",
        (F.col("year").cast("int") * 12 + F.col("mo").cast("int")).alias(
            "num_mo"
        ),
    )
    .withColumn(
        "avg_count",
        F.avg("count_temp").over(one_month_before_and_after),
    )
    .show()
)
# +------+----+---+---+----+----------+------+------------------+
# |   stn|year| mo| da|temp|count_temp|num_mo|         avg_count|
# +------+----+---+---+----+----------+------+------------------+
# |041680|2019| 02| 19|16.1|        15| 24230|             15.75|
# |998012|2019| 03| 02|31.4|        24| 24231|13.833333333333334|
# |996470|2019| 03| 12|55.6|        12| 24231|13.833333333333334|
# |998166|2019| 03| 20|34.8|        12| 24231|13.833333333333334|
# |917350|2019| 04| 21|82.6|         9| 24232|              13.6|
# |998252|2019| 04| 18|44.7|        11| 24232|              13.6|
# |076470|2019| 06| 07|65.0|        24| 24234|              24.0|
# |719200|2019| 10| 09|60.5|        11| 24238|              12.5|
# |949110|2019| 11| 23|54.9|        14| 24239|15.333333333333334|
# |994979|2019| 12| 11|21.3|        21| 24240|              17.5|
# +------+----+---+---+----+----------+------+------------------+


# end::sol10_7[]


# # tag::code-window-exo-total[]

# mystery_window = Window.partitionBy()

# gsod.withColumn("min_temp", F.min("temp").over(mystery_window)).where(
#     "temp = min_temp"
# ).select("year", "mo", "da", "stn", "temp").orderBy(
#     "year", "mo", "da"
# ).show()

# # end::code-window-exo-total[]
