#!/usr/bin/env python3

# pylint: disable=C0413,C0411,C0116

"""Code inside the chapter for the book PySpark in Action (chapter 8)."""

# tag::ch08-simple-rdd[]

from pyspark.sql import SparkSession

# Session shared by every listing in this file.
spark = SparkSession.builder.getOrCreate()

# Deliberately heterogeneous list: int, str, float, tuple and dict elements.
collection = [1, "two", 3.0, ("four", 4), {"five": 5}]  # <1>

# The session's SparkContext provides `parallelize`, used just below.
sc = spark.sparkContext  # <2>

# Turn the local Python list into an RDD.
collection_rdd = sc.parallelize(collection)  # <3>

# Printing the RDD shows its description (sample below), not its elements.
print(collection_rdd)
# ParallelCollectionRDD[0] at parallelize at PythonRDD.scala:195  <4>

# end::ch08-simple-rdd[]

# tag::ch08-rdd-map[]
from py4j.protocol import Py4JJavaError


def add_one(value):
    """Return ``value`` incremented by one.

    No type check is performed: the ``+`` operator raises ``TypeError`` when
    ``value`` cannot be added to an ``int`` (``str``, ``tuple``, ``dict``...).
    """
    incremented = value + 1  # <1>
    return incremented


# Rebind the name to the RDD mapped with `add_one`. This statement sits
# outside the `try`: the failure is handled around `collect()` below.
collection_rdd = collection_rdd.map(add_one)  # <2>

# `add_one` cannot handle the non-numeric elements, so this is expected to
# fail. The Py4JJavaError is swallowed on purpose so the remaining listings
# of the script still run.
try:
    print(collection_rdd.collect())  # <3>
except Py4JJavaError:
    pass

# Stack trace galore! The important bit is that you'll get one of the following:
# TypeError: can only concatenate str (not "int") to str
# TypeError: unsupported operand type(s) for +: 'dict' and 'int'
# TypeError: can only concatenate tuple (not "int") to tuple

# end::ch08-rdd-map[]

# tag::ch08-rdd-map2[]

# Start again from the original list: `collection_rdd` was rebound above to
# the RDD mapped with `add_one`, whose `collect()` fails.
collection_rdd = sc.parallelize(collection)  # <1>


def safer_add_one(value):
    """Return ``value + 1``, or ``value`` itself when the addition fails.

    Only ``TypeError`` is handled; any other exception raised by ``+``
    propagates to the caller.
    """
    try:
        outcome = value + 1
    except TypeError:
        outcome = value  # <2>
    return outcome


# Apply the tolerant version: elements that cannot be incremented are kept
# unchanged instead of raising, so `collect()` succeeds this time.
collection_rdd = collection_rdd.map(safer_add_one)

print(collection_rdd.collect())
# [2, 'two', 4.0, ('four', 4), {'five': 5}] <3>

# end::ch08-rdd-map2[]

# tag::ch08-rdd-filter[]

# Keep only the elements that are instances of float or int; the str, tuple
# and dict elements are dropped.
collection_rdd = collection_rdd.filter(
    lambda elem: isinstance(elem, (float, int))
)

print(collection_rdd.collect())
# [2, 4.0]

# end::ch08-rdd-filter[]

# tag::ch08-rdd-reduce[]

from operator import add  # <1>

# A fresh RDD made only of integers, so every element can be added.
collection_rdd = sc.parallelize([4, 7, 9, 1, 3])

# Combine all the elements with `add`: 4 + 7 + 9 + 1 + 3.
print(collection_rdd.reduce(add))  # 24
# end::ch08-rdd-reduce[]

# tag::ch08-df-to-rdd[]

# One-column data frame (column named "column") holding 1, 2 and 3.
df = spark.createDataFrame([[1], [2], [3]], schema=["column"])

# The `rdd` attribute exposes the data frame as an RDD...
print(df.rdd)
# MapPartitionsRDD[22] at javaToPython at NativeMethodAccessorImpl.java:0

# ...whose elements are Row objects, one per record.
print(df.rdd.collect())
# [Row(column=1), Row(column=2), Row(column=3)]

# end::ch08-df-to-rdd[]

# tag::ch08-fraction-df[]
import pyspark.sql.functions as F
import pyspark.sql.types as T

# Every [numerator, denominator] pair with a numerator in 0..99 and a
# denominator in 1..99, so no denominator is zero.
# NOTE(review): this list shares its name with the stdlib `fractions`
# module that is imported from later in the file.
fractions = [[x, y] for x in range(100) for y in range(1, 100)]  # <1>

frac_df = spark.createDataFrame(fractions, ["numerator", "denominator"])

# Replace the two columns by a single array column named `fraction`.
frac_df = frac_df.select(
    F.array(F.col("numerator"), F.col("denominator")).alias(
        "fraction"
    ),  # <2>
)

frac_df.show(5, False)
# +--------+
# |fraction|
# +--------+
# |[0, 1]  |
# |[0, 2]  |
# |[0, 3]  |
# |[0, 4]  |
# |[0, 5]  |
# +--------+
# only showing top 5 rows
# end::ch08-fraction-df[]

# tag::ch08-udf-python[]

from fractions import Fraction  # <1>
from typing import Tuple, Optional  # <2>

Frac = Tuple[int, int]  # <3>


def py_reduce_fraction(frac: Frac) -> Optional[Frac]:  # <4>
    """Return ``frac`` in lowest terms, or ``None`` if its denominator is 0.

    ``frac`` is a ``(numerator, denominator)`` pair; the reduction itself is
    delegated to ``fractions.Fraction``.
    """
    numerator, denominator = frac
    if not denominator:
        # A falsy denominator (0) has no reduced form.
        return None
    reduced = Fraction(numerator, denominator)
    return reduced.numerator, reduced.denominator


assert py_reduce_fraction((3, 6)) == (1, 2)  # <5>
assert py_reduce_fraction((1, 0)) is None


def py_fraction_to_float(frac: Frac) -> Optional[float]:
    """Return ``frac`` as a float, or ``None`` when its denominator is 0.

    ``frac`` is a ``(numerator, denominator)`` pair of integers.
    """
    numerator, denominator = frac
    return numerator / denominator if denominator else None


assert py_fraction_to_float((2, 8)) == 0.25
assert py_fraction_to_float((10, 0)) is None
# end::ch08-udf-python[]

# tag::ch08-udf1[]
# Spark-side type of a fraction: an array of longs. It is the return type
# declared for the UDF created below.
SparkFrac = T.ArrayType(T.LongType())  # <1>

# Wrap the plain Python function into a UDF that can be applied to columns.
reduce_fraction = F.udf(py_reduce_fraction, SparkFrac)  # <2>

# Add the reduced form of each `fraction` as a `reduced_fraction` column.
frac_df = frac_df.withColumn(
    "reduced_fraction", reduce_fraction(F.col("fraction"))  # <3>
)

frac_df.show(5, False)
# +--------+----------------+
# |fraction|reduced_fraction|
# +--------+----------------+
# |[0, 1]  |[0, 1]          |
# |[0, 2]  |[0, 1]          |
# |[0, 3]  |[0, 1]          |
# |[0, 4]  |[0, 1]          |
# |[0, 5]  |[0, 1]          |
# +--------+----------------+
# only showing top 5 rows

# end::ch08-udf1[]

# tag::ch08-udf2[]


@F.udf(T.DoubleType())  # <1>
def fraction_to_float(frac: Optional[Frac]) -> Optional[float]:
    """Convert a ``(numerator, denominator)`` fraction into a float.

    The decorator registers the function as a Spark UDF returning a double,
    so it is called once per row with the column value.

    A null column value reaches the function as ``None``. The column this
    UDF is applied to comes from ``py_reduce_fraction``, which returns
    ``None`` for a zero denominator, so ``None`` is mapped to ``None``
    instead of failing on the tuple unpacking.

    Returns:
        ``numerator / denominator``, or ``None`` when the input is ``None``
        or the denominator is 0.
    """
    if frac is None:
        # Null row: nothing to unpack, propagate the null.
        return None
    num, denom = frac
    if denom:
        return num / denom
    return None


# Convert each reduced fraction to its floating-point value.
frac_df = frac_df.withColumn(
    "fraction_float", fraction_to_float(F.col("reduced_fraction"))
)

# Show a few distinct (reduced fraction, float value) pairs.
frac_df.select("reduced_fraction", "fraction_float").distinct().show(
    5, False
)
# +----------------+-------------------+
# |reduced_fraction|fraction_float     |
# +----------------+-------------------+
# |[3, 50]         |0.06               |
# |[3, 67]         |0.04477611940298507|
# |[7, 76]         |0.09210526315789473|
# |[9, 23]         |0.391304347826087  |
# |[9, 25]         |0.36               |
# +----------------+-------------------+
# only showing top 5 rows
# The undecorated Python function stays reachable through the UDF's `func`
# attribute, so it can be checked directly on a plain tuple.
assert fraction_to_float.func((1, 2)) == 0.5  # <2>

# end::ch08-udf2[]
