#!/usr/bin/env python3

from pyspark.sql import SparkSession
from pyspark.sql import SparkSession
from functools import reduce
import pyspark.sql.functions as F
import pyspark.sql.types as T

# Obtain the Spark session (created on first use) with the driver memory
# set to "8g", and keep a shortcut to its SparkContext for the RDD code.
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)
sc = spark.sparkContext

# Stack the yearly GSOD parquet files (2010 through 2020 inclusive) into a
# single data frame, then discard the records that cannot be used.
gsod = (
    reduce(
        # Columns absent from one of the yearly files are tolerated.
        lambda stacked, yearly: stacked.unionByName(
            yearly, allowMissingColumns=True
        ),
        [
            spark.read.parquet(f"./data/gsod_noaa/gsod{year}.parquet")
            for year in range(2010, 2021)
        ],
    )
    # A usable record needs a complete date and a temperature.
    .dropna(subset=["year", "mo", "da", "temp"])
    # NOTE(review): 9999.9 is presumably the "missing temperature" marker
    # of the data set -- confirm against the GSOD documentation.
    .filter(F.col("temp") != 9999.9)
    .drop("date")
)

# tag::sol8_1[]
# RDD holding the integers 0 through 99.
exo_rdd = sc.parallelize(list(range(100)))

from operator import add

# Count the records with map/reduce only: every record becomes a 1 and the
# ones are then summed.
sol8_1 = exo_rdd.map(lambda _record: 1).reduce(add)
print(sol8_1)  # => 100
# end::sol8_1[]

# tag::exo8_2[]
# Five values of different types: integers, None, an empty list and a float.
a_rdd = sc.parallelize([0, 1, None, [], 0.0])

# The predicate returns the element itself, so each element's own truth
# value decides whether it is kept. The collected list is neither assigned
# nor printed here.
a_rdd.filter(lambda x: x).collect()
# end::exo8_2[]

# tag::sol8_3[]
from typing import Optional


@F.udf(T.DoubleType())
def temp_to_temp(
    value: Optional[float], from_temp: str, to_temp: str
) -> Optional[float]:
    """Convert a temperature between the F, C, R and K scales.

    Args:
        value: Temperature reading expressed on the ``from_temp`` scale.
        from_temp: Source scale code: "F", "C", "R" or "K".
        to_temp: Target scale code: "F", "C", "R" or "K".

    Returns:
        The converted temperature, or ``None`` when ``value`` is null or
        when either scale code is not one of the four accepted values.
    """
    acceptable_values = ["F", "C", "R", "K"]
    if (
        to_temp not in acceptable_values
        or from_temp not in acceptable_values
    ):
        return None

    # A null cell reaches a Python UDF as None. Propagate it instead of
    # letting the arithmetic below raise TypeError and fail the whole job.
    if value is None:
        return None

    def f_to_c(value):
        return (value - 32.0) * 5.0 / 9.0

    def c_to_f(value):
        return value * 9.0 / 5.0 + 32.0

    K_OVER_C = 273.15  # offset between Kelvin and Celsius
    R_OVER_F = 459.67  # offset between Rankine and Fahrenheit

    # We can reduce our decision tree by only converting from C and F:
    # K and R are first shifted onto their relative scale.
    if from_temp == "K":
        value -= K_OVER_C
        from_temp = "C"
    if from_temp == "R":
        value -= R_OVER_F
        from_temp = "F"

    if from_temp == "C":
        if to_temp == "C":
            return value
        if to_temp == "F":
            return c_to_f(value)
        if to_temp == "K":
            return value + K_OVER_C
        if to_temp == "R":
            return c_to_f(value) + R_OVER_F
    else:  # from_temp == "F":
        if to_temp == "C":
            return f_to_c(value)
        if to_temp == "F":
            return value
        if to_temp == "K":
            return f_to_c(value) + K_OVER_C
        if to_temp == "R":
            return value + R_OVER_F


# Keep the station, the date parts and the raw temperature, and append the
# result of converting "temp" from the F scale to the K scale with the UDF.
sol8_3 = gsod.select(
    *["stn", "year", "mo", "da", "temp"],
    temp_to_temp(F.col("temp"), F.lit("F"), F.lit("K")),
)
sol8_3.show(n=5)
# +------+----+---+---+----+------------------------+
# |   stn|year| mo| da|temp|temp_to_temp(temp, F, K)|
# +------+----+---+---+----+------------------------+
# |359250|2010| 03| 16|38.4|       276.7055555555555|
# |725745|2010| 08| 16|64.4|                  291.15|
# |386130|2010| 01| 24|42.4|      278.92777777777775|
# |386130|2010| 03| 21|34.0|      274.26111111111106|
# |386130|2010| 09| 18|54.1|      285.42777777777775|
# +------+----+---+---+----+------------------------+
# only showing top 5 rows

# end::sol8_3[]

# tag::sol8_4[]
@F.udf(T.DoubleType())
def naive_udf(value: float) -> float:
    """Return ``value`` multiplied by 3.14159."""
    scaled = value * 3.14159
    return scaled


# end::sol8_4[]

# tag::exo8_5[]
from fractions import Fraction
from typing import Tuple, Optional

# A fraction is modelled as a (numerator, denominator) pair of integers.
Frac = Tuple[int, int]


def py_reduce_fraction(frac: Frac) -> Optional[Frac]:
    """Return ``frac`` in lowest terms.

    Args:
        frac: A (numerator, denominator) pair.

    Returns:
        The reduced (numerator, denominator) pair, or ``None`` when the
        denominator is falsy (for example zero), since no fraction can be
        built from it.
    """
    numerator, denominator = frac
    if not denominator:
        return None
    lowest_terms = Fraction(numerator, denominator)
    return lowest_terms.numerator, lowest_terms.denominator


# Spark-side type used for a fraction: an array of longs.
SparkFrac = T.ArrayType(T.LongType())

# Expose the plain Python reducer as a UDF whose result is a SparkFrac.
reduce_fraction = F.udf(f=py_reduce_fraction, returnType=SparkFrac)

# Every numerator in 0..99 paired with every denominator in 1..99.
fractions = [
    [numerator, denominator]
    for numerator in range(100)
    for denominator in range(1, 100)
]

test_frac = spark.createDataFrame(fractions, ["numerator", "denominator"])

# Pack both columns into a single array column, then reduce it with the UDF.
test_frac = test_frac.select(
    F.array("numerator", "denominator").alias("fraction")
).withColumn("reduced_fraction", reduce_fraction("fraction"))

# end::exo8_5[]

# tag::sol8_5[]

@F.udf(SparkFrac)
def add_fractions(
    left: Optional[Frac], right: Optional[Frac]
) -> Optional[Frac]:
    """Add two fractions and return the sum in lowest terms.

    Args:
        left: First operand as a (numerator, denominator) pair, or None.
        right: Second operand as a (numerator, denominator) pair, or None.

    Returns:
        The reduced (numerator, denominator) pair of the sum, or ``None``
        when an operand is null or has a zero denominator.
    """
    # A null cell reaches a Python UDF as None (a reduced fraction is null
    # when its denominator was zero). Propagate the null rather than fail
    # on the tuple unpacking below.
    if left is None or right is None:
        return None
    left_num, left_denom = left
    right_num, right_denom = right
    if not (left_denom and right_denom):  # avoid division by zero
        return None
    answer = Fraction(left_num, left_denom) + Fraction(right_num, right_denom)
    return answer.numerator, answer.denominator

# Add every reduced fraction to itself and display the first five rows.
test_frac.withColumn(
    "sum_frac",
    add_fractions(F.col("reduced_fraction"), F.col("reduced_fraction")),
).show(5)
# +--------+----------------+--------+
# |fraction|reduced_fraction|sum_frac|
# +--------+----------------+--------+
# |  [0, 1]|          [0, 1]|  [0, 1]|
# |  [0, 2]|          [0, 1]|  [0, 1]|
# |  [0, 3]|          [0, 1]|  [0, 1]|
# |  [0, 4]|          [0, 1]|  [0, 1]|
# |  [0, 5]|          [0, 1]|  [0, 1]|
# +--------+----------------+--------+
# only showing top 5 rows
# end::sol8_5[]

# tag::sol8_6[]
def py_reduce_fraction(frac: Frac) -> Optional[Frac]:
    """Reduce a fraction, keeping only results that fit in signed 64 bits.

    Args:
        frac: A (numerator, denominator) pair of integers.

    Returns:
        The reduced (numerator, denominator) pair, or ``None`` when the
        denominator is falsy (for example zero) or when either reduced
        component falls outside the range -2**63 .. 2**63 - 1.
    """
    long_min, long_max = -(2 ** 63), 2 ** 63 - 1
    numerator, denominator = frac
    if not denominator:
        return None
    reduced = Fraction(numerator, denominator)
    parts = (reduced.numerator, reduced.denominator)
    # Both components must be representable, otherwise the result is dropped.
    if all(long_min <= part <= long_max for part in parts):
        return parts
    return None
# end::sol8_6[]
