#!/usr/bin/env python3

# pylint: disable=C0413,C0411,C0116

from pyspark.sql import SparkSession
from functools import reduce
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd

# Spark session used throughout the script; the driver is given 8g of memory.
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

# Stack one parquet file per year, 2010 through 2020 inclusive, matching
# columns by name. `allowMissingColumns=True` lets files whose column sets
# differ be unioned anyway (absent columns become null).
gsod = (
    reduce(
        lambda acc, nxt: acc.unionByName(nxt, allowMissingColumns=True),
        [
            spark.read.parquet(f"./data/gsod_noaa/gsod{yr}.parquet")
            for yr in range(2010, 2021)
        ],
    )
    # Keep only rows that have a complete date and a temperature reading.
    .dropna(subset=["year", "mo", "da", "temp"])
    # NOTE(review): 9999.9 looks like the dataset's missing-temperature
    # sentinel; confirm against the GSOD documentation.
    .where(F.col("temp") != 9999.9)
    .drop("date")
)

# tag::sol9_1[]
# Spark return type for the color-code UDF: the codes produced (1, 2, 3)
# are integers.
WHICH_TYPE = T.IntegerType()
# Pandas type used for both the input and the output annotation of
# color_to_num, which makes it a Series-to-Series pandas UDF.
WHICH_SIGNATURE = pd.Series
# end::sol9_1[]

# tag::exo9_1[]

exo9_1 = pd.Series(data=["red", "blue", "blue", "yellow"])  # sample colors


def color_to_num(colors: WHICH_SIGNATURE) -> WHICH_SIGNATURE:
    """Map color names to numeric codes.

    "red" -> 1, "blue" -> 2, "yellow" -> 3. Any other value becomes a
    missing value (NaN) in the result.
    """
    # Build the lookup table once per call instead of once per element, and
    # let pandas perform the vectorized dictionary lookup.
    color_codes = {"red": 1, "blue": 2, "yellow": 3}
    return colors.map(color_codes)


# Try the plain pandas function first, before promoting it to a Spark UDF.
color_to_num(exo9_1)

# 0    1
# 1    2
# 2    2
# 3    3

# Wrap the same function as a pandas UDF, with WHICH_TYPE as its Spark
# return type.
color_to_num_udf = F.pandas_udf(f=color_to_num, returnType=WHICH_TYPE)

# end::exo9_1[]


# tag::sol9_2[]


def temp_to_temp(
    value: pd.Series, from_temp: str, to_temp: str
) -> pd.Series:
    """Convert a Series of temperatures from one unit to another.

    Supported units are "F" (Fahrenheit), "C" (Celsius), "R" (Rankine)
    and "K" (Kelvin).

    Args:
        value: temperatures expressed in ``from_temp``. The Series is not
            modified.
        from_temp: unit of ``value``.
        to_temp: unit to convert to.

    Returns:
        A new Series holding the converted temperatures. If either unit is
        not one of the supported ones, a Series of ``None`` with the same
        index as ``value``.
    """
    acceptable_values = ("F", "C", "R", "K")
    if (
        to_temp not in acceptable_values
        or from_temp not in acceptable_values
    ):
        return value.apply(lambda _: None)

    def f_to_c(value):
        return (value - 32.0) * 5.0 / 9.0

    def c_to_f(value):
        return value * 9.0 / 5.0 + 32.0

    K_OVER_C = 273.15
    R_OVER_F = 459.67

    # Shrink the decision tree by first expressing Kelvin as Celsius and
    # Rankine as Fahrenheit. Rebind with `value = value - x` rather than
    # `value -= x`, because augmented assignment on a pandas Series updates
    # the caller's Series in place.
    if from_temp == "K":
        value = value - K_OVER_C
        from_temp = "C"
    elif from_temp == "R":
        value = value - R_OVER_F
        from_temp = "F"

    if from_temp == "C":
        if to_temp == "C":
            return value
        if to_temp == "F":
            return c_to_f(value)
        if to_temp == "K":
            return value + K_OVER_C
        # to_temp == "R"
        return c_to_f(value) + R_OVER_F

    # from_temp == "F"
    if to_temp == "C":
        return f_to_c(value)
    if to_temp == "F":
        return value
    if to_temp == "K":
        return f_to_c(value) + K_OVER_C
    # to_temp == "R"
    return value + R_OVER_F


# end::sol9_2[]

# tag::sol9_3[]


def scale_temperature_C(temp_by_day: pd.DataFrame) -> pd.DataFrame:
    """Min-max normalize a site's temperatures, computed in Celsius.

    The ``temp`` column is treated as Fahrenheit and converted to Celsius
    before computing ``temp_norm = (t - min) / (max - min)``.

    Returns the ``stn``, ``year``, ``mo``, ``da`` and ``temp`` columns of
    ``temp_by_day`` (``temp`` is passed through unconverted) plus
    ``temp_norm``. If the temperature is constant for the whole window,
    ``temp_norm`` defaults to 0.5.
    """
    celsius = (temp_by_day.temp - 32.0) * 5.0 / 9.0
    lowest = celsius.min()
    highest = celsius.max()
    kept = temp_by_day[["stn", "year", "mo", "da", "temp"]]
    if lowest == highest:
        return kept.assign(temp_norm=0.5)
    return kept.assign(temp_norm=(celsius - lowest) / (highest - lowest))

# end::sol9_3[]

# tag::sol9_4[]

# The declared schema has four columns, but scale_temperature_C returns six
# (stn, year, mo, da, temp, temp_norm). The job therefore fails once an
# action such as show() runs.
sol9_4 = gsod.groupby("year", "mo").applyInPandas(
    scale_temperature_C,
    schema=(
        "year string, mo string, "
        "temp double, temp_norm double"
    ),
)

# The RuntimeError raised inside the Python worker reaches the driver wrapped
# in a PythonException, which does not subclass RuntimeError. Catching
# RuntimeError alone let the error escape and abort the script.
from pyspark.sql.utils import PythonException

try:
    sol9_4.show(5, False)
except (RuntimeError, PythonException) as err:
    print(err)

# RuntimeError: Number of columns of the returned pandas.DataFrame doesn't match
# specified schema. Expected: 4 Actual: 6

# end::sol9_4[]

# tag::exo9_5[]

from sklearn.linear_model import LinearRegression


@F.pandas_udf(T.DoubleType())
def rate_of_change_temperature(day: pd.Series, temp: pd.Series) -> float:
    """Slope of a least-squares line fitting temperature against day.

    Takes two Series and returns a single float, so PySpark treats it as a
    Series-to-scalar (grouped aggregate) pandas UDF.
    """
    # scikit-learn expects a 2-D feature matrix: one column, one row per day.
    day_matrix = day.astype("int").values.reshape(-1, 1)
    fitted = LinearRegression().fit(X=day_matrix, y=temp)
    return fitted.coef_[0]


# end::exo9_5[]

# tag::sol9_5[]

from sklearn.linear_model import LinearRegression
from typing import Sequence


@F.pandas_udf(T.ArrayType(T.DoubleType()))
def rate_of_change_temperature_ic(
    day: pd.Series, temp: pd.Series
) -> Sequence[float]:
    """Intercept and slope of a least-squares fit of temperature vs. day.

    Returns the pair ``(intercept, slope)``, which Spark stores as an array
    of doubles.
    """
    # One feature column, one row per observation, as scikit-learn requires.
    day_matrix = day.astype(int).values.reshape(-1, 1)
    fitted = LinearRegression().fit(X=day_matrix, y=temp)
    intercept, slope = fitted.intercept_, fitted.coef_[0]
    return intercept, slope


# One (intercept, slope) pair per station and month.
(
    gsod.groupby("stn", "year", "mo")
    .agg(rate_of_change_temperature_ic("da", "temp").alias("sol9_5"))
    .show(5, truncate=50)
)
# +------+----+---+------------------------------------------+
# |   stn|year| mo|                                    sol9_5|
# +------+----+---+------------------------------------------+
# |008268|2010| 07| [135.79999999999973, -2.1999999999999877]|
# |008401|2011| 11| [67.51655172413793, -0.30429365962180205]|
# |008411|2014| 02| [82.69682539682537, -0.02662835249042155]|
# |008411|2015| 12|  [84.03264367816091, -0.0476974416017797]|
# |008415|2016| 01|[82.10193548387099, -0.013225806451612926]|
# +------+----+---+------------------------------------------+
# only showing top 5 rows


# end::sol9_5[]
