#!/usr/bin/env python3

# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-order
# pylint: disable=ungrouped-imports
# pylint: disable=missing-function-docstring


# tag::ch12-import-and-setup-spark[]
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T

# Create (or reuse, via getOrCreate) the Spark session used by the whole
# chapter. The driver is given 8g of memory.
# NOTE(review): 8g is a hard-coded setting — adjust to the machine running this.
spark = (
    SparkSession.builder.appName("Recipes ML model - Are you a dessert?")
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

# end::ch12-import-and-setup-spark[]

# tag::ch12-data-ingestion-and-description[]
# Load the raw recipes CSV: column names come from the header row and column
# types are inferred. As the schema listing below shows, rating and calories
# come back as strings (a few malformed rows, inspected further down, contain
# text in those columns); they are cleaned and cast later in the script.
food = spark.read.csv(
    "./data/recipes/epi_r.csv", inferSchema=True, header=True
)

# Shape of the raw data: (number of rows, number of columns).
print(food.count(), len(food.columns))  # <1>
# 20057 680  # <1>

food.printSchema()
# root
#  |-- title: string (nullable = true)
#  |-- rating: string (nullable = true)
#  |-- calories: string (nullable = true)
#  |-- protein: double (nullable = true)
#  |-- fat: double (nullable = true)
#  |-- sodium: double (nullable = true)
#  |-- #cakeweek: double (nullable = true)
#  |-- #wasteless: double (nullable = true) <2>
#  |-- 22-minute meals: double (nullable = true) <3>
#  |-- 3-ingredient recipes: double (nullable = true)
#  |-- 30 days of groceries: double (nullable = true)
#  ...
#  |-- créme de cacao: double (nullable = true)
#  |-- crêpe: double (nullable = true)
#  |-- cr��me de cacao: double (nullable = true) <4>
# ... and many more columns

# end::ch12-data-ingestion-and-description[]


# tag::ch12-sanitize-columns[]
def sanitize_column_name(name):
    """Return a cleaned-up version of a column name.

    Spaces, dashes and slashes become underscores and "&" becomes "and".
    Afterwards, every character that is not a letter, a digit or an
    underscore is dropped. Letters and digits are tested with
    ``str.isalpha`` / ``str.isdigit``, so accented letters (e.g. "é")
    are kept."""
    # One-pass substitution table; none of the replacement texts contain a
    # character that is itself substituted, so order does not matter.
    substitutions = str.maketrans(
        {" ": "_", "-": "_", "/": "_", "&": "and"}  # <1>
    )

    def keep(char):
        return char.isalpha() or char.isdigit() or char == "_"  # <2>

    return "".join(filter(keep, name.translate(substitutions)))


# Rename every column to its sanitized form. toDF renames positionally, so the
# column order is unchanged.
# NOTE(review): if two raw names sanitize to the same string, the result has
# duplicate column names — consider asserting uniqueness here.
food = food.toDF(*[sanitize_column_name(name) for name in food.columns])
# end::ch12-sanitize-columns[]


# tag::ch12-summary-columns[]
# for x in food.columns:
#     food.select(x).summary().show()

# many tables looking like this one.
# +-------+--------------------+
# |summary|               clove|
# +-------+--------------------+
# |  count|               20052|  <1>
# |   mean|0.009624975064831438|
# | stddev| 0.09763611178399834|
# |    min|                 0.0|  <2>
# |    25%|                 0.0|  <2>
# |    50%|                 0.0|  <2>
# |    75%|                 0.0|  <2>
# |    max|                 1.0|  <2>
# +-------+--------------------+
#
# end::ch12-summary-columns[]


# tag::ch12-binary-columns[]
import pandas as pd

# Let pandas print every row of the (≈680-entry) result instead of truncating.
pd.set_option("display.max_rows", 1000)  # <1>

# One aggregated boolean per column: does it contain exactly two distinct
# values? The result is a single row, so collecting it with toPandas() is cheap.
is_binary = food.agg(
    *[
        (F.size(F.collect_set(x)) == 2).alias(x)  # <2>
        for x in food.columns
    ]
).toPandas()

# Reshape the 1-row frame into a Series indexed by (column name, 0).
# Bare expression: it only displays something in an interactive session.
is_binary.unstack()  # <3>

# title                     0    False
# rating                    0    False
# calories                  0    False
# protein                   0    False
# fat                       0    False
# sodium                    0    False
# cakeweek                  0    False
# wasteless                 0    False
# 22_minute_meals           0     True
# 3_ingredient_recipes      0     True
# ... the rest are all = True
# end::ch12-binary-columns[]

# Script-friendly equivalent: print only the columns that are NOT binary.
print(is_binary.unstack()[~is_binary.unstack()])


# tag::ch12-more-binary-columns[]
# cakeweek and wasteless were flagged as non-binary: list their distinct values
# (show 1 row, no truncation).
food.agg(*[F.collect_set(x) for x in ("cakeweek", "wasteless")]).show(
    1, False
)

# +-------------------------------+----------------------+
# |collect_set(cakeweek)          |collect_set(wasteless)|
# +-------------------------------+----------------------+
# |[0.0, 1.0, 1188.0, 24.0, 880.0]|[0.0, 1.0, 1439.0]    |
# +-------------------------------+----------------------+

# Inspect the rows carrying the out-of-range values. The last column is added
# for context. In the output, pieces of the title spill into rating,
# presumably because the CSV parser split these rows incorrectly.
food.where("cakeweek > 1.0 or wasteless > 1.0").select(
    "title", "rating", "wasteless", "cakeweek", food.columns[-1]  # <1>
).show()

# +--------------------+--------------------+---------+--------+------+
# |               title|              rating|wasteless|cakeweek|turkey|
# +--------------------+--------------------+---------+--------+------+
# |"Beet Ravioli wit...| Aged Balsamic Vi...|      0.0|   880.0|   0.0|
# |"Seafood ""Catapl...|            Vermouth|   1439.0|    24.0|   0.0|
# |"""Pot Roast"" of...| Aunt Gloria-Style "|      0.0|  1188.0|   0.0|
# +--------------------+--------------------+---------+--------+------+
# end::ch12-more-binary-columns[]


# tag::ch12-removing-bad-records[]
# Keep a row only if cakeweek and wasteless each hold 0.0, 1.0 or null, which
# drops the malformed rows found above. The explicit isNull() is needed because
# isin() on a null value yields null, which where() treats as false.
food = food.where(
    (
        F.col("cakeweek").isin([0.0, 1.0])  # <1>
        | F.col("cakeweek").isNull()  # <1>
    )
    & (
        F.col("wasteless").isin([0.0, 1.0])  # <1>
        | F.col("wasteless").isNull()  # <1>
    )
)

print(food.count(), len(food.columns))

# 20054 680  <2>
# end::ch12-removing-bad-records[]


# tag::ch12-column-classification[]
# Column roles for the model. Anything that is not an identifier, a
# continuous measurement or the target is treated as a binary feature.
IDENTIFIERS = ["title"]

CONTINUOUS_COLUMNS = ["rating", "calories", "protein", "fat", "sodium"]

TARGET_COLUMN = ["dessert"]  # <1>

BINARY_COLUMNS = [
    column_name
    for column_name in food.columns
    if not (
        column_name in IDENTIFIERS
        or column_name in CONTINUOUS_COLUMNS
        or column_name in TARGET_COLUMN
    )
]
# end::ch12-column-classification[]


# tag::ch12-drop-null-records[]
# First drop rows where every non-identifier column is null (nothing but a
# title is left). Then drop rows with no target value, which cannot be
# used as labelled examples.
food = food.dropna(
    how="all",
    subset=[x for x in food.columns if x not in IDENTIFIERS],  # <1>
)

food = food.dropna(subset=TARGET_COLUMN)  # <1>

print(food.count(), len(food.columns))
# 20049 680  <2>
# end::ch12-drop-null-records[]


# tag::ch12-null-binary[]
# A missing binary flag is interpreted as "feature absent" (0.0).
food = food.fillna(0.0, subset=BINARY_COLUMNS)

# Spot check on the first binary column only, not on all of them.
print(food.where(F.col(BINARY_COLUMNS[0]).isNull()).count())  # => 0
# end::ch12-null-binary[]


# tag::ch12-casting-clean-up[]
from typing import Optional


@F.udf(T.BooleanType())
def is_a_number(value: Optional[str]) -> bool:
    """Spark UDF: tell whether ``value`` can be parsed by Python's float().

    ``None`` and the empty string are reported as numbers, so a
    ``where(is_a_number(...))`` filter keeps rows with missing values."""
    if value:
        try:
            float(value)  # <1>
        except ValueError:
            return False
    return True


# Show the continuous columns of the rows whose rating cannot be parsed as a
# number.
food.where(~is_a_number(F.col("rating"))).select(
    *CONTINUOUS_COLUMNS
).show()

# +---------+------------+-------+----+------+
# |   rating|    calories|protein| fat|sodium|
# +---------+------------+-------+----+------+
# | Cucumber| and Lemon "|   3.75|null|  null|  <2>
# +---------+------------+-------+----+------+
# end::ch12-casting-clean-up[]


# tag::ch12-continuous-casting[]
# Keep only rows whose rating/calories parse as numbers (null and "" pass the
# is_a_number check), then cast those two string columns to double.
# NOTE(review): Python's float() accepts some spellings (e.g. "1_000") that
# Spark's string-to-double cast may not; such values would pass the filter and
# could become null (or raise under ANSI mode) after the cast — confirm.
for column in ["rating", "calories"]:
    food = food.where(is_a_number(F.col(column)))
    food = food.withColumn(column, F.col(column).cast(T.DoubleType()))

print(food.count(), len(food.columns))

# 20048 680  <1>
# end::ch12-continuous-casting[]


# tag::ch12-continuous-summary[]
# Distribution of the continuous columns, including tail percentiles. The max
# values sit far above the 99th percentile (see output), which motivates the
# capping applied next.
food.select(*CONTINUOUS_COLUMNS).summary(
    "mean",
    "stddev",
    "min",
    "1%",
    "5%",
    "50%",
    "95%",
    "99%",
    "max",
).show()

# +-------+------------------+------------------+------------------+
# |summary|            rating|          calories|           protein|
# +-------+------------------+------------------+------------------+
# |   mean| 3.714460295291301|6324.0634571930705|100.17385283565179|
# | stddev|1.3409187660508959|359079.83696340164|3840.6809971287403|
# |    min|               0.0|               0.0|               0.0|
# |     1%|               0.0|              18.0|               0.0|
# |     5%|               0.0|              62.0|               0.0|
# |    50%|             4.375|             331.0|               8.0|
# |    95%|               5.0|            1318.0|              75.0|
# |    99%|               5.0|            3203.0|             173.0|
# |    max|               5.0|       3.0111218E7|          236489.0|
# +-------+------------------+------------------+------------------+

# +-------+-----------------+-----------------+
# |summary|              fat|           sodium|
# +-------+-----------------+-----------------+
# |   mean|346.9398083953107|6226.927244193346|
# | stddev|20458.04034412409|333349.5680370268|
# |    min|              0.0|              0.0|
# |     1%|              0.0|              1.0|
# |     5%|              0.0|              5.0|
# |    50%|             17.0|            294.0|
# |    95%|             85.0|           2050.0|
# |    99%|            207.0|           5661.0|
# |    max|        1722763.0|       2.767511E7|
# +-------+-----------------+-----------------+
# end::ch12-continuous-summary[]

# tag::ch12-null-average-imputation[]
# Upper caps per column, taken from the 99th-percentile row of the summary
# above.
maximum = {
    "calories": 3203.0,  # <1>
    "protein": 173.0,  # <1>
    "fat": 207.0,  # <1>
    "sodium": 5661.0,  # <1>
}

# Cap each column at its maximum. Nulls are passed through explicitly because
# F.least skips null arguments: least(null, v) would return v and silently
# replace missing values with the cap.
for k, v in maximum.items():
    food = food.withColumn(
        k,
        F.when(F.isnull(F.col(k)), F.col(k)).otherwise(  # <2>
            F.least(F.col(k), F.lit(v))
        ),
    )
# end::ch12-null-average-imputation[]


# tag::ch12-rare-binary[]
inst_sum_of_binary_columns = [
    F.sum(F.col(x)).alias(x) for x in BINARY_COLUMNS
]

# One aggregation pass computing, for each binary column, the sum of its
# values (the number of rows flagged 1, assuming 0/1 values). asDict() keeps
# the BINARY_COLUMNS order.
sum_of_binary_columns = (
    food.select(*inst_sum_of_binary_columns).head().asDict()  # <1>
)

num_rows = food.count()
# A binary feature is "too rare" when fewer than 10 rows have it or fewer
# than 10 rows lack it; either way it carries almost no signal. A None sum
# (no non-null value at all, e.g. an empty frame) is also treated as too
# rare instead of raising TypeError in the comparison.
too_rare_features = [
    k
    for k, v in sum_of_binary_columns.items()
    if v is None or v < 10 or v > (num_rows - 10)
]

len(too_rare_features)  # => 167

print(too_rare_features)
# ['cakeweek', 'wasteless', '30_days_of_groceries',
#  [...]
#  'yuca', 'cookbooks', 'leftovers']

# Filter instead of taking a set difference, so BINARY_COLUMNS keeps its
# original, reproducible order. Set iteration order for strings changes
# between interpreter runs because of hash randomization.
_too_rare = set(too_rare_features)
BINARY_COLUMNS = [x for x in BINARY_COLUMNS if x not in _too_rare]  # <2>
# end::ch12-rare-binary[]


# tag::ch12-feature-engineering[]

# Share of calories coming from protein and from fat. The factors 4 and 9
# presumably convert grams to kcal — confirm the units of protein/fat.
# The ratio is only computed when calories is non-zero. Otherwise (0 or
# null calories) F.when yields null, and the fillna below turns that into
# 0.0. The explicit guard avoids relying on x / 0 returning null, which
# raises an error instead when spark.sql.ansi.enabled is true.
food = food.withColumn(
    "protein_ratio",
    F.when(  # <1>
        F.col("calories") != 0, F.col("protein") * 4 / F.col("calories")
    ),
).withColumn(
    "fat_ratio",
    F.when(F.col("calories") != 0, F.col("fat") * 9 / F.col("calories")),
)  # <1>

food = food.fillna(0.0, subset=["protein_ratio", "fat_ratio"])

CONTINUOUS_COLUMNS += ["protein_ratio", "fat_ratio"]  # <2>


# end::ch12-feature-engineering[]

# The fillna(0.0) above replaces the null ratios obtained when calories is 0 (division by zero) or missing.

# tag::ch12-vector-columns[]

from pyspark.ml.feature import VectorAssembler

# Pack the continuous columns into a single vector column. Correlation.corr
# operates on one column of vectors.
continuous_features = VectorAssembler(
    inputCols=CONTINUOUS_COLUMNS, outputCol="continuous_features"
)

vector_food = food.select(CONTINUOUS_COLUMNS)
# Remove rows with a null in any continuous column. With its default
# handleInvalid="error", VectorAssembler rejects null inputs.
for x in CONTINUOUS_COLUMNS:
    vector_food = vector_food.where(~F.isnull(F.col(x)))  # <1>

vector_variable = continuous_features.transform(vector_food)

vector_variable.select("continuous_features").show(3, False)

# +---------------------------------------------------------------------+
# |continuous_features                                                  |
# +---------------------------------------------------------------------+
# |[2.5,426.0,30.0,7.0,559.0,0.28169014084507044,0.14788732394366197]   |
# |[4.375,403.0,18.0,23.0,1439.0,0.17866004962779156,0.5136476426799007]|
# |[3.75,165.0,6.0,7.0,165.0,0.14545454545454545,0.38181818181818183]   |
# +---------------------------------------------------------------------+
# only showing top 3 rows

vector_variable.select("continuous_features").printSchema()

# root
#  |-- continuous_features: vector (nullable = true)

# end::ch12-vector-columns[]


# tag::ch12-correlation[]

from pyspark.ml.stat import Correlation

# Pearson correlation matrix (the default method) between all continuous
# features.
correlation = Correlation.corr(
    vector_variable, "continuous_features"  # <1>
)

correlation.printSchema()

# root
#  |-- pearson(continuous_features): matrix (nullable = false)  <2>

# The result is a single row holding one matrix; convert it to a NumPy array.
correlation_array = correlation.head()[0].toArray()  # <3>

# Label rows and columns with the feature names. The order matches the
# VectorAssembler inputCols used to build the vector.
correlation_pd = pd.DataFrame(
    correlation_array,  # <4>
    index=CONTINUOUS_COLUMNS,  # <4>
    columns=CONTINUOUS_COLUMNS,  # <4>
)

print(correlation_pd.iloc[:, :4])

#                  rating  calories   protein       fat # <5>
# rating         1.000000 -0.019631 -0.020484 -0.027028 # <5>
# calories      -0.019631  1.000000  0.958442  0.978012 # <5>
# protein       -0.020484  0.958442  1.000000  0.947768 # <5>
# fat           -0.027028  0.978012  0.947768  1.000000 # <5>
# sodium        -0.032499  0.938167  0.936153  0.914338 # <5>
# protein_ratio -0.026485  0.029879  0.121392  0.086444 # <5>
# fat_ratio     -0.010696 -0.007470  0.000260  0.029411 # <5>

print(correlation_pd.iloc[:, 4:])

#                  sodium  protein_ratio  fat_ratio
# rating        -0.032499      -0.026485  -0.010696
# calories       0.938167       0.029879  -0.007470
# protein        0.936153       0.121392   0.000260
# fat            0.914338       0.086444   0.029411
# sodium         1.000000       0.049268  -0.005783
# protein_ratio  0.049268       1.000000   0.111694
# fat_ratio     -0.005783       0.111694   1.000000

# end::ch12-correlation[]


# tag::ch12-mean-imputation[]
from pyspark.ml.feature import Imputer

OLD_COLS = ["calories", "protein", "fat", "sodium"]
NEW_COLS = ["calories_i", "protein_i", "fat_i", "sodium_i"]

# Fill missing values of each OLD_COLS column with that column's mean. The
# result goes to the matching NEW_COLS column, and the originals are kept.
imputer = Imputer(
    strategy="mean",  # <1>
    inputCols=OLD_COLS,  # <2>
    outputCols=NEW_COLS,  # <3>
)

imputer_model = imputer.fit(food)  # <4>

# Swap the raw columns for their imputed counterparts. The filtering
# comprehension keeps the remaining columns in their original order. The
# column list, and any feature vector assembled from it, is therefore
# reproducible across runs, unlike iterating over a set difference.
CONTINUOUS_COLUMNS = [
    x for x in CONTINUOUS_COLUMNS if x not in OLD_COLS
] + NEW_COLS  # <5>

# end::ch12-mean-imputation[]


# tag::ch12-mean-imputation-application[]

# Add the *_i columns: same values as the originals, with nulls replaced by
# the column means learned in fit().
food_imputed = imputer_model.transform(food)

# Check: rows with a missing calories value now carry the mean in calories_i.
food_imputed.where("calories is null").select("calories", "calories_i").show(
    5, False
)
# +--------+-----------------+
# |calories|calories_i       |
# +--------+-----------------+
# |null    |475.5222194325885| <1>
# |null    |475.5222194325885| <1>
# |null    |475.5222194325885| <1>
# |null    |475.5222194325885| <1>
# |null    |475.5222194325885| <1>
# +--------+-----------------+
# only showing top 5 rows

# end::ch12-mean-imputation-application[]


# tag::ch12-variable-scaling[]

from pyspark.ml.feature import MinMaxScaler

# Continuous columns, excluding the *_ratio features.
CONTINUOUS_NB = [x for x in CONTINUOUS_COLUMNS if "ratio" not in x]

# NOTE(review): rating is not imputed. With its default handleInvalid="error",
# VectorAssembler will fail if any rating is null. Presumably none remain
# after the filtering above — confirm.
continuous_assembler = VectorAssembler(
    inputCols=CONTINUOUS_NB, outputCol="continuous"
)

food_features = continuous_assembler.transform(food_imputed)

# Rescale each vector component to [0, 1] (MinMaxScaler defaults), using the
# per-component min/max observed in food_features.
continuous_scaler = MinMaxScaler(
    inputCol="continuous",
    outputCol="continuous_scaled",
)

food_features = continuous_scaler.fit(food_features).transform(
    food_features
)

food_features.select("continuous_scaled").show(3, False)
# +------------------------------------------------------...+
# |continuous_scaled                                     ...|
# +------------------------------------------------------...+
# |[0.5,0.13300031220730565,0.17341040462427745,0.0338164...|
# |[0.875,0.12581954417733376,0.10404624277456646,0.11111...|
# |[0.75,0.051514205432407124,0.03468208092485549,0.03381...|
# +------------------------------------------------------...+
# only showing top 3 rows

# end::ch12-variable-scaling[]

# Persist both the feature frame (with assembled/scaled vectors) and the
# cleaned frame, overwriting any previous run's output.
food_features.write.parquet("./data/food_features.parquet", mode="overwrite")
food.write.parquet("./data/food.parquet", mode="overwrite")
