#!/usr/bin/env python3

# pylint: disable=C0413,C0411,C0116

from pyspark.sql import SparkSession
from functools import reduce
import pyspark.sql.functions as F

# Get (or create) the Spark session, asking for 8g of driver memory.
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

# tag::ch09-read-local-alternate[]

# Load one parquet file per year (2010 through 2020) and stack them by
# column name; allowMissingColumns=True fills columns absent from a given
# year with nulls instead of failing.
yearly_frames = [
    spark.read.parquet(f"./data/gsod_noaa/gsod{yr}.parquet")
    for yr in range(2010, 2021)
]
gsod = reduce(
    lambda acc, nxt: acc.unionByName(nxt, allowMissingColumns=True),
    yearly_frames,
)
del yearly_frames
# Drop rows missing a date part or the temperature, and rows carrying the
# 9999.9 placeholder (presumably the dataset's "no reading" marker).
gsod = (
    gsod.dropna(subset=["year", "mo", "da", "temp"])
    .where(F.col("temp") != 9999.9)
    .drop("date")
)

# end::ch09-read-local-alternate[]

# Restrict the data to 2018 so that running the whole script stays fast.
gsod = gsod.where(F.col("year") == 2018)

# tag::ch09-scalar-udf[]

import pandas as pd
import pyspark.sql.types as T


@F.pandas_udf(T.DoubleType())  # <1>
def f_to_c(degrees: pd.Series) -> pd.Series:  # <2>
    """Convert a Series of Fahrenheit temperatures to Celsius.

    Series-to-Series pandas UDF: it receives a batch of values as a pandas
    Series and returns a Series of the same length."""
    return degrees.sub(32).mul(5).div(9)


# end::ch09-scalar-udf[]

# tag::ch09-scalar-udf-application[]


# Add a Celsius column computed by the pandas UDF, then peek at a few
# distinct (temp, temp_c) pairs.
gsod = gsod.withColumn("temp_c", f_to_c(F.col("temp")))
(
    gsod.select("temp", "temp_c")
    .distinct()
    .show(5)
)

# +-----+-------------------+
# | temp|             temp_c|
# +-----+-------------------+
# | 37.2| 2.8888888888888906|
# | 85.9| 29.944444444444443|
# | 53.5| 11.944444444444445|
# | 71.6| 21.999999999999996|
# |-27.6|-33.111111111111114|
# +-----+-------------------+
# only showing top 5 rows

# end::ch09-scalar-udf-application[]

# tag::ch09-iterator-series-udf[]

from time import sleep
from typing import Iterator


@F.pandas_udf(T.DoubleType())
def f_to_c2(degrees: Iterator[pd.Series]) -> Iterator[pd.Series]:  # <1>
    """Convert Fahrenheit to Celsius, consuming an iterator of Series batches.

    The ``sleep(5)`` runs once, before the first batch is converted
    (presumably to simulate a costly one-time setup step), rather than once
    per batch."""
    sleep(5)  # <2>
    yield from (batch.sub(32).mul(5).div(9) for batch in degrees)  # <3>


(
    gsod.select("temp", f_to_c2(F.col("temp")).alias("temp_c"))
    .distinct()
    .show(5)
)
# +-----+-------------------+
# | temp|             temp_c|
# +-----+-------------------+
# | 37.2| 2.8888888888888906|
# | 85.9| 29.944444444444443|
# | 53.5| 11.944444444444445|
# | 71.6| 21.999999999999996|
# |-27.6|-33.111111111111114|
# +-----+-------------------+
# only showing top 5 rows

# end::ch09-iterator-series-udf[]

# fmt:off
# tag::ch09-iterator-multiple-series-udf[]

from typing import Tuple


@F.pandas_udf(T.DateType())
def create_date(
    year_mo_da: Iterator[Tuple[pd.Series, pd.Series, pd.Series]]
) -> Iterator[pd.Series]:
    """Combine year, month and day Series into a single date Series.

    Each item of the iterator is a (year, month, day) tuple of Series; the
    three are assembled into a DataFrame so pd.to_datetime can build dates."""
    for batch in year_mo_da:
        year, mo, da = batch
        parts = pd.DataFrame({"year": year, "month": mo, "day": da})
        yield pd.to_datetime(parts)


(
    gsod.select(
        "year",
        "mo",
        "da",
        create_date(F.col("year"), F.col("mo"), F.col("da")).alias("date"),
    )
    .distinct()
    .show(5)
)

# end::ch09-iterator-multiple-series-udf[]
# fmt:on

# Used to be called Grouped Aggregate
# tag::ch09-series-to-scalar[]
from sklearn.linear_model import LinearRegression  # <1>


@F.pandas_udf(T.DoubleType())
def rate_of_change_temperature(day: pd.Series, temp: pd.Series) -> float:
    """Return the slope of a least-squares line of ``temp`` against ``day``.

    Series-to-scalar pandas UDF, used as an aggregate. ``day`` (the ``da``
    column in this script) is cast to int and reshaped into the single
    feature column that LinearRegression expects."""
    model = LinearRegression()  # <2>
    features = day.astype(int).values.reshape(-1, 1)
    model.fit(X=features, y=temp)  # <3>
    return model.coef_[0]  # <4>


# end::ch09-series-to-scalar[]


# tag::ch09-series-to-scalar-agg[]

# One temperature slope per station and month.
result = gsod.groupby("stn", "year", "mo").agg(
    rate_of_change_temperature(gsod["da"], gsod["temp"]).alias("rt_chg_temp")  # <1>
)

result.show(5, truncate=False)
# +------+----+---+---------------------+
# |stn   |year|mo |rt_chg_temp          |
# +------+----+---+---------------------+
# |010250|2018|12 |-0.01014397905759162 |
# |011120|2018|11 |-0.01704736746691528 |
# |011150|2018|10 |-0.013510329829648423|
# |011510|2018|03 |0.020159116598556657 |
# |011800|2018|06 |0.012645501680677372 |
# +------+----+---+---------------------+
# only showing top 5 rows

# end::ch09-series-to-scalar-agg[]


# tag::ch09-grouped-map-udf-verbose[]


def scale_temperature(temp_by_day: pd.DataFrame) -> pd.DataFrame:
    """Min-max scale the ``temp`` column of one group into ``temp_norm``.

    Keeps only the ``stn``, ``year``, ``mo``, ``da`` and ``temp`` columns and
    appends ``temp_norm``. When every temperature in the group is equal the
    scaling would divide by zero, so ``temp_norm`` is set to 0.5 instead."""
    temp = temp_by_day["temp"]
    lowest, highest = temp.min(), temp.max()
    kept = temp_by_day[["stn", "year", "mo", "da", "temp"]]
    if lowest != highest:
        return kept.assign(temp_norm=(temp - lowest) / (highest - lowest))
    return kept.assign(temp_norm=0.5)


# end::ch09-grouped-map-udf-verbose[]

# tag::ch09-grouped-map-udf[]

# Normalize temperatures per station-month. The declared schema lists the
# columns scale_temperature returns, in the same order.
gsod_map = (
    gsod.groupby("stn", "year", "mo")
    .applyInPandas(
        scale_temperature,
        schema="stn string, year string, mo string, "
        "da string, temp double, temp_norm double",
    )
)

gsod_map.show(5, truncate=False)
# +------+----+---+---+----+-------------------+
# |stn   |year|mo |da |temp|temp_norm          |
# +------+----+---+---+----+-------------------+
# |010250|2018|12 |08 |21.8|0.06282722513089001|
# |010250|2018|12 |27 |28.3|0.40314136125654443|
# |010250|2018|12 |31 |29.1|0.4450261780104712 |
# |010250|2018|12 |19 |27.6|0.36649214659685864|
# |010250|2018|12 |04 |36.6|0.8376963350785339 |
# +------+----+---+---+----+-------------------+

# end::ch09-grouped-map-udf[]

# result.groupby("stn").agg(
#     F.sum(F.when(F.col("rt_chg_temp") > 0, 1).otherwise(0)).alias(
#         "temp_increasing"
#     ),
#     F.count("rt_chg_temp").alias("count"),
# ).where(F.col("count") > 6).select(
#     F.col("stn"),
#     (F.col("temp_increasing") / F.col("count")).alias(
#         "temp_increasing_ratio"
#     ),
# ).orderBy(
#     "temp_increasing_ratio"
# ).show(
#     5, False
# )
# # +------+---------------------+  <2>
# # |stn   |temp_increasing_ratio|
# # +------+---------------------+
# # |681115|0.0                  |
# # |384572|0.0                  |
# # |682720|0.0                  |
# # |672310|0.0                  |
# # |654530|0.08333333333333333  |
# # +------+---------------------+
# # only showing top 5 rows
# # end::ch09-agg[]

# tag::ch09-local[]
# Bring one station-month back to the driver as a pandas DataFrame and call
# the UDF's underlying pandas function (``.func``) directly, without Spark.
# ``temp_norm`` is only added by scale_temperature, so it exists on
# ``gsod_map`` but not on ``gsod``; selecting from ``gsod`` would make the
# ``gsod_local["temp_norm"]`` lookup below raise a KeyError.
gsod_local = gsod_map.where(
    "year = '2018' and mo = '08' and stn = '710920'"
).toPandas()


print(
    rate_of_change_temperature.func(
        gsod_local["da"], gsod_local["temp_norm"]
    )
)
# -0.007830974115511494
# end::ch09-local[]

# tag::appc-max[]
# One max() aggregation per column, unpacked into agg() below.
maxes = list(map(F.max, ("temp", "temp_norm")))

(
    gsod_map.groupby("stn")
    .agg(*maxes)
    .show(5)  # <1>
)

# +------+---------+--------------+
# |   stn|max(temp)|max(temp_norm)|
# +------+---------+--------------+
# |296450|     77.7|           1.0|
# |633320|     81.4|           1.0|
# |720375|     79.2|           1.0|
# |725165|     83.5|           1.0|
# |868770|     94.6|           1.0|
# +------+---------+--------------+
# only showing top 5 rows
# end::appc-max[]
