#!/usr/bin/env python3

from pyspark.ml.feature import MinMaxScaler, VectorAssembler
from pyspark.sql import SparkSession

# A single Spark session shared by the whole chapter; the driver is asked for
# 8 GB of memory.
spark = SparkSession.builder.appName(
    "Recipes ML model - Are you a dessert?"
).config("spark.driver.memory", "8g").getOrCreate()

# Tiny one-cell placeholder DataFrame; the real recipe data is read from
# ./data/food.parquet further down and rebinds this name.
food = spark.createDataFrame([[1]])

# tag::ch13-transformer-estimator-example[]
CONTINUOUS_NB = [
    "rating",
    "calories_i",
    "protein_i",
    "fat_i",
    "sodium_i",
]

# Transformer: packs the continuous columns into a single vector column.
continuous_assembler = VectorAssembler(
    outputCol="continuous",
    inputCols=CONTINUOUS_NB,
)

# Estimator: has to be fit on data before it can rescale that vector.
continuous_scaler = MinMaxScaler(
    outputCol="continuous_scaled",
    inputCol="continuous",
)

# end::ch13-transformer-estimator-example[]

# Accessing a Param attribute directly gives back the Param object itself
# (printed as "<uid>__<name>"), not the value it currently holds.
# tag::ch13-transformer-direct-param[]
print(continuous_assembler.outputCol)
# VectorAssembler_e18a6589d2d5__outputCol <1>
# end::ch13-transformer-direct-param[]

# The getter returns the value currently set for the Param.
# tag::ch13-transformer-getter-param[]
print(continuous_assembler.getOutputCol())  # => continuous
# end::ch13-transformer-getter-param[]

# explainParam shows the Param's doc string together with its default and
# current values.
# tag::ch13-transformer-explain-param[]
print(continuous_assembler.explainParam("outputCol"))
# outputCol: output column name.  <1>
# (default: VectorAssembler_e18a6589d2d5__output, current: continuous)  <2>
# end::ch13-transformer-explain-param[]

# getParam looks a Param up by name and returns the Param object (not its
# value), same as the direct attribute access above.
# tag::ch13-transformer-getParam[]
print(continuous_assembler.getParam("outputCol"))
# end::ch13-transformer-getParam[]

# Setters modify the transformer in place; the new value is visible right away.
# tag::ch13-transformer-setter-param[]
continuous_assembler.setOutputCol("more_continuous")  # <1>

print(continuous_assembler.getOutputCol())  # => more_continuous
# end::ch13-transformer-setter-param[]

# setParams sets several Params at once. Params that are not passed (outputCol
# here) keep their current value, which is still "more_continuous" from the
# setOutputCol call earlier in the script.
# tag::ch13-transformer-setParams[]
continuous_assembler.setParams(
    inputCols=["one", "two", "three"], handleInvalid="skip"
)
print(continuous_assembler.explainParams())
# handleInvalid: How to handle invalid data (NULL and NaN values). [...]
#     (default: error, current: skip)
# inputCols: input column names. (current: ['one', 'two', 'three'])  <1>
# outputCol: output column name.
#     (default: VectorAssembler_e18a6589d2d5__output, current: more_continuous)
# end::ch13-transformer-setParams[]

# clear() drops the user-set value, so the getter falls back to the default.
# tag::ch13-transformer-clear[]
continuous_assembler.clear(continuous_assembler.handleInvalid)

print(continuous_assembler.getHandleInvalid())  # => error <1>
# end::ch13-transformer-clear[]

# Plain assignment only binds a second name to the same object, so a setter
# called through either name is visible through both.
# tag::ch13-transformer-inplace[]
new_continuous_assembler = continuous_assembler

new_continuous_assembler.setOutputCol("new_output")

print(new_continuous_assembler.getOutputCol())  # => new_output
print(continuous_assembler.getOutputCol())  # => new_output <1>
# end::ch13-transformer-inplace[]

# copy() returns a separate instance with its own Param values, so changing
# the copy leaves the original untouched.
# tag::ch13-transformer-copy[]
copy_continuous_assembler = continuous_assembler.copy()

copy_continuous_assembler.setOutputCol("copy_output")

print(copy_continuous_assembler.getOutputCol())  # => copy_output
print(continuous_assembler.getOutputCol())  # => new_output <1>
# end::ch13-transformer-copy[]

# tag::ch13-preliminary-pipeline[]
import pyspark.ml.feature as MF
from pyspark.ml import Pipeline

# Fill missing nutrient values with the column mean, writing "<name>_i" columns.
imputer = MF.Imputer(  # <1>
    inputCols=["calories", "protein", "fat", "sodium"],
    outputCols=["calories_i", "protein_i", "fat_i", "sodium_i"],
    strategy="mean",
)

# Same five continuous columns as CONTINUOUS_NB, packed into one vector.
continuous_assembler = MF.VectorAssembler(  # <1>
    outputCol="continuous",
    inputCols=CONTINUOUS_NB,
)

# Rescale each element of the assembled vector.
continuous_scaler = MF.MinMaxScaler(  # <1>
    outputCol="continuous_scaled",
    inputCol="continuous",
)

# Stages run in list order: impute, then assemble, then scale.
food_pipeline = Pipeline(  # <2>
    stages=[imputer, continuous_assembler, continuous_scaler]
)

# end::ch13-preliminary-pipeline[]

# Indicator columns of the recipe dataset that go into the feature vector
# unchanged (presumably 0.0/1.0 flags — the name suggests so; confirm against
# the parquet schema).
# NOTE(review): both "bon_appétit" and "bon_apptit" are listed; the latter
# looks like a mis-encoded duplicate in the source data — confirm both columns
# really exist.
# fmt:off
BINARY_COLUMNS = [ "sausage", "washington_dc", "poppy", "bean", "lemon_juice",
                   "low_sodium", "soy", "sauté", "cornmeal", "lobster", "tangerine", "punch",
                   "gourmet", "cream_cheese", "north_carolina", "dairy_free", "pie",
                   "lima_bean", "maryland", "stuffing_dressing", "tilapia", "vodka", "cheddar",
                   "mustard_greens", "candy_thermometer", "lamb_chop", "whiskey", "pea",
                   "sandwich_theory", "orange", "fall", "quick_and_healthy", "broccoli",
                   "california", "sesame_oil", "gin", "south_carolina", "france", "wasabi",
                   "arugula", "cinco_de_mayo", "campari", "coconut", "low_no_sugar",
                   "no_sugar_added", "low_cal", "back_to_school", "walnut", "prune", "rum",
                   "parsley", "apricot", "garlic", "seed", "rutabaga", "papaya", "fat_free",
                   "graduation", "shellfish", "sage", "low_cholesterol", "olive", "asian_pear",
                   "ohio", "bass", "pennsylvania", "spirit", "white_wine", "halloween",
                   "honey", "healthy", "dill", "tea", "broccoli_rabe", "tamarind",
                   "st_patricks_day", "sherry", "lime_juice", "pork", "chive",
                   "kidney_friendly", "sparkling_wine", "alcoholic", "ground_beef", "wok",
                   "low_fat", "advance_prep_required", "mango", "pine_nut", "pizza",
                   "wild_rice", "butterscotch_caramel", "condiment", "pork_chop", "artichoke",
                   "brown_rice", "green_bean", "sugar_conscious", "stew", "monterey_jack",
                   "avocado", "brine", "grape", "quail", "jerusalem_artichoke", "swiss_cheese",
                   "champagne", "escarole", "pescatarian", "quince", "raw", "florida",
                   "red_wine", "soy_sauce", "trout", "capers", "mothers_day", "mixer", "boil",
                   "new_mexico", "hominy_cornmeal_masa", "parade", "bell_pepper", "gouda",
                   "cabbage", "grill_barbecue", "dinner", "non_alcoholic", "fontina",
                   "pickles", "fish", "tortillas", "cognac_armagnac", "squash", "cashew",
                   "rack_of_lamb", "couscous", "maple_syrup", "leek", "beef_tenderloin",
                   "appetizer", "noodle", "christmas", "pear", "jam_or_jelly", "smoothie",
                   "snack", "créme_de_cacao", "christmas_eve", "bastille_day", "slow_cooker",
                   "sweet_potato_yam", "self", "scallop", "thanksgiving", "oatmeal", "kahlúa",
                   "pittsburgh", "pork_rib", "lettuce", "cardamom", "no_cook", "sour_cream",
                   "chile", "high_fiber", "legume", "steam", "engagement_party", "wedding",
                   "georgia", "zucchini", "cantaloupe", "berry", "quick_and_easy", "quinoa",
                   "date", "curry", "spice", "jalapeño", "radicchio", "halibut", "chill",
                   "grapefruit", "pistachio", "tapioca", "poultry_sausage", "oat",
                   "butternut_squash", "pecan", "watermelon", "fortified_wine", "brandy",
                   "ricotta", "22_minute_meals", "mardi_gras", "los_angeles", "sugar_snap_pea",
                   "tree_nut", "washington", "simmer", "carrot", "frangelico", "paleo",
                   "fourth_of_july", "brie", "ramekin", "parsnip", "pasadena", "meat",
                   "deep_fry", "cocktail", "melon", "sangria", "vegetarian", "bon_appétit",
                   "jícama", "poultry", "buttermilk", "nectarine", "one_pot_meal",
                   "lunar_new_year", "brunch", "weelicious", "hot_pepper", "ice_cream", "egg",
                   "bread", "kiwi", "chickpea", "frozen_dessert", "swordfish", "vermouth",
                   "vegetable", "steak", "brisket", "shavuot", "clam", "asparagus",
                   "tomatillo", "braise", "barley", "new_jersey", "new_years_eve",
                   "peanut_butter", "sandwich", "chocolate", "tree_nut_free", "breadcrumbs",
                   "new_years_day", "backyard_bbq", "onion", "phyllo_puff_pastry_dough",
                   "hot_drink", "cherry", "hazelnut", "cranberry", "calvados", "turkey",
                   "beef", "anise", "drinks", "rosh_hashanah_yom_kippur", "coriander",
                   "blender", "goose", "cheese", "lentil", "tofu", "hanukkah", "eggplant",
                   "vanilla", "shallot", "edible_gift", "cookies", "kwanzaa", "picnic",
                   "lamb_shank", "honeydew", "seattle", "fennel", "feta", "nutmeg",
                   "double_boiler", "macadamia_nut", "molasses", "casserole_gratin", "broil",
                   "diwali", "anniversary", "spinach", "pastry", "margarita", "pomegranate",
                   "oregon", "snapper", "party", "saffron", "pork_tenderloin",
                   "house_and_garden", "potluck", "mint", "roast", "chambord",
                   "kosher_for_passover", "iced_tea", "clove", "prosciutto", "texas", "duck",
                   "potato", "tequila", "collard_greens", "dried_fruit", "mustard", "spring",
                   "blueberry", "ramadan", "virginia", "soup_stew", "food_processor",
                   "breakfast", "oktoberfest", "salad_dressing", "fathers_day", "veal",
                   "stir_fry", "pan_fry", "strawberry", "paprika", "tailgating", "cinnamon",
                   "tomato", "fig", "orange_juice", "soy_free", "butter", "plum",
                   "santa_monica", "crab", "lemon", "oyster", "sauce", "lemongrass",
                   "minnesota", "chestnut", "almond", "3_ingredient_recipes", "peanut_free",
                   "plantain", "cocktail_party", "valentines_day", "nut", "watercress",
                   "super_bowl", "milk_cream", "kid_friendly", "parmesan", "passion_fruit",
                   "pernod", "bacon", "side", "birthday", "ginger", "tropical_fruit",
                   "michigan", "pasta_maker", "wine", "harpercollins", "low_carb", "turnip",
                   "bok_choy", "fry", "anchovy", "chile_pepper", "mushroom", "herb",
                   "chartreuse", "coffee_grinder", "leafy_green", "bulgur", "semolina",
                   "microwave", "root_vegetable", "snack_week", "sake", "bon_apptit",
                   "rosemary", "currant", "tarragon", "raisin", "cookie", "triple_sec",
                   "mandoline", "oscars", "beer", "port", "dip", "easter", "rhubarb",
                   "new_york", "connecticut", "freeze_chill", "ham", "blackberry",
                   "mozzarella", "guava", "bitters", "kirsch", "lime", "grill",
                   "flaming_hot_summer", "passover", "apple", "poker_game_night", "chicken",
                   "hors_doeuvre", "rye", "basil", "cumin", "cake", "kosher", "persimmon",
                   "pumpkin", "shower", "summer", "pepper", "rice", "yellow_squash", "yogurt",
                   "shrimp", "oregano", "wheat_gluten_free", "game", "okra", "endive", "lunch",
                   "cod", "scotch", "winter", "cauliflower", "condiment_spread", "poach",
                   "banana", "pasta", "vegan", "pineapple", "fruit_juice", "lamb", "mussel",
                   "beef_rib", "chard", "kale", "bourbon", "cilantro", "blue_cheese", "beet",
                   "green_onion_scallion", "citrus", "cottage_cheese", "dairy", "whole_wheat",
                   "family_reunion", "tuna", "squid", "purim", "cucumber", "amaretto",
                   "low_sugar", "raspberry", "sukkot", "caraway", "marsala", "portland",
                   "corn", "marinate", "kumquat", "peanut", "brussel_sprout", "peach",
                   "sesame", "massachusetts", "ground_lamb", "salmon", "bake", "aperitif",
                   "radish", "ireland", "liqueur", "ice_cream_machine", "coffee", "colorado",
                   "vinegar", "mayonnaise", "kentucky_derby", "horseradish", "buffet", "candy",
                   "goat_cheese", "drink", "missouri", "beef_shank", "seafood", "celery",
                   "thyme", "salad", "fruit", ]
# fmt:on

# Load the real recipe dataset, replacing the one-cell placeholder DataFrame
# created right after the Spark session.
food = spark.read.parquet("./data/food.parquet")

# tag::ch13-final-assembly[]

# Final feature vector: every binary flag, then the scaled continuous vector,
# then the two ratio columns, in that order.
preml_assembler = MF.VectorAssembler(
    outputCol="features",
    inputCols=[
        *BINARY_COLUMNS,  # <1>
        "continuous_scaled",
        "protein_ratio",
        "fat_ratio",
    ],
)

# Extend the pipeline so that it ends with the final assembler.
food_pipeline.setStages(
    [
        imputer,
        continuous_assembler,
        continuous_scaler,
        preml_assembler,
    ]
)

food_pipeline_model = food_pipeline.fit(food)  # <1>
food_features = food_pipeline_model.transform(food)  # <2>
# end::ch13-final-assembly[]


# The features column prints in sparse-vector notation:
# (size, [indices of non-zero entries], [their values]).
# tag::ch13-sparse-vector[]
food_features.select("title", "dessert", "features").show(5, truncate=30)
# +------------------------------+-------+------------------------------+
# |                         title|dessert|                      features|
# +------------------------------+-------+------------------------------+
# |      Swiss Honey-Walnut Tart |    1.0|(513,[30,47,69,154,214,251,...| <1>
# |Mascarpone Cheesecake with ...|    1.0|(513,[30,47,117,154,181,188...|
# |         Beef and Barley Soup |    0.0|(513,[7,30,44,118,126,140,1...|
# |                     Daiquiri |    0.0|(513,[49,210,214,408,424,50...|
# |Roast Beef and Watercress W...|    0.0|(513,[12,131,161,173,244,25...|
# +------------------------------+-------+------------------------------+
# only showing top 5 rows
# end::ch13-sparse-vector[]

# The column type (VectorUDT) says nothing about what each slot means. The
# column metadata does: it maps every vector index to a source-column name.
# A vector input such as continuous_scaled contributes one suffixed name per
# element (continuous_scaled_0 ... continuous_scaled_4).
# tag::ch13-metadata[]

print(food_features.schema["features"])

# StructField(features,VectorUDT,true)

print(food_features.schema["features"].metadata)  # <1>
# {
#     "ml_attr": {
#         "attrs": {
#             "numeric": [
#                 {"idx": 0, "name": "sausage"},
#                 {"idx": 1, "name": "washington_dc"},
#                 {"idx": 2, "name": "poppy"},
#                 [...]
#                 {"idx": 510, "name": "continuous_scaled_4"}, <2>
#                 {"idx": 511, "name": "protein_ratio"},
#                 {"idx": 512, "name": "fat_ratio"},
#             ]
#         },
#         "num_attrs": 513,
#     }
# }
# end::ch13-metadata[]


# Logistic regression is added as the last pipeline stage. It reads the
# assembled "features" vector and the "dessert" label, and writes its output
# to the "prediction" column.
# tag::ch13-logistic-regression[]
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(
    featuresCol="features", labelCol="dessert", predictionCol="prediction"
)

food_pipeline.setStages(
    [
        imputer,
        continuous_assembler,
        continuous_scaler,
        preml_assembler,
        lr,  # <1>
    ]
)
# end::ch13-logistic-regression[]

# Split roughly 70/30 into train/test. The seed (13) makes the split
# repeatable for the same data and partitioning. The pipeline is fit on the
# training split only, and the training split is cached because fitting reads
# it in several stages.
# tag::ch13-fit-predict[]
train, test = food.randomSplit([0.7, 0.3], 13)  # <1>

train.cache()

food_pipeline_model = food_pipeline.fit(train)  # <2>
results = food_pipeline_model.transform(test)  # <2>
# end::ch13-fit-predict[]


# rawPrediction holds one raw score per class and probability holds the
# matching class probabilities. With the default 0.5 threshold, prediction is
# 1.0 when the class-1 probability exceeds 0.5.
# tag::ch13-show-prediction[]

results.select("prediction", "rawPrediction", "probability").show(3, False)
# +----------+----------------------+--------------------+
# |prediction|rawPrediction         |probability         |
# +----------+----------------------+--------------------+
# |0.0       |[11.98907,-11.9890722]|[0.9999937,6.2116-6]|
# |0.0       |[32.94732,-32.947325] |[0.99999,4.88498-15]|
# |1.0       |[-1.32753,1.32753254] |[0.209567,0.7904]   |
# +----------+----------------------+--------------------+
# end::ch13-show-prediction[]


# Cross-tabulate the true label (rows) against the predicted label (columns)
# on the test split. This gives the confusion matrix.
# tag::ch13-confusion-matrix[]

results.groupby("dessert").pivot("prediction").count().show()
# +-------+----+----+
# |dessert| 0.0| 1.0| <1>
# +-------+----+----+
# |    0.0|4950|  77|
# |    1.0| 104|1005|
# +-------+----+----+
#    <2>

# end::ch13-confusion-matrix[]


# evaluate() re-applies the fitted model to the rows it receives, so only the
# label ("dessert") and "features" columns are needed. precisionByLabel and
# recallByLabel are ordered by label value, so index 1 is the dessert == 1.0
# class.
# tag::ch13-precision-recall[]
lr_model = food_pipeline_model.stages[-1]  # <1>
metrics = lr_model.evaluate(results.select("title", "dessert", "features"))
# BinaryLogisticRegressionSummary (computed on the test rows passed in)

print(f"Model precision: {metrics.precisionByLabel[1]}")  # <2>
print(f"Model recall: {metrics.recallByLabel[1]}")  # <2>

# Model precision: 0.9288354898336414  # <3>
# Model recall: 0.9062218214607755     # <3>
# end::ch13-precision-recall[]


# The same metrics computed with the older RDD-based MLlib API. It expects an
# RDD of (prediction, label) pairs, which is why prediction is selected first.
# tag::ch13-precision-recall-rdd[]
from pyspark.mllib.evaluation import MulticlassMetrics

predictionAndLabel = results.select("prediction", "dessert").rdd

metrics_rdd = MulticlassMetrics(predictionAndLabel)

print(f"Model precision: {metrics_rdd.precision(1.0)}")
print(f"Model recall: {metrics_rdd.recall(1.0)}")

# Model precision: 0.9288354898336414
# Model recall: 0.9062218214607755

# end::ch13-precision-recall-rdd[]


# Area under the ROC curve on the test split, computed from rawPrediction.
# NOTE: despite its name, `accuracy` holds the AUC value, not classification
# accuracy.
# tag::ch13-binary-evaluator[]
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(
    labelCol="dessert",  # <1>
    rawPredictionCol="rawPrediction",  # <1>
    metricName="areaUnderROC",
)

accuracy = evaluator.evaluate(results)
print(f"Area under ROC = {accuracy} ")
# Area under ROC = 0.9927442079816466

# end::ch13-binary-evaluator[]


# tag::ch13-roc-curve[]

import matplotlib.pyplot as plt

# Pull the ROC points to the driver once. Each Row carries a matching
# (FPR, TPR) pair, so both axes come from a single Spark job and every x value
# stays aligned with its y value. Collecting the two columns separately runs
# the computation twice and hands matplotlib lists of Row objects, not floats.
# NOTE(review): lr_model.summary is the *training* summary. `metrics.roc`
# (the summary evaluated on the test split above) would plot held-out
# performance matching the test AUC — confirm which one is intended.
roc_points = lr_model.summary.roc.collect()

plt.figure(figsize=(5, 5))
plt.plot([0, 1], [0, 1], "r--")  # diagonal: a classifier guessing at random
plt.plot(
    [point["FPR"] for point in roc_points],
    [point["TPR"] for point in roc_points],
)
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.show()

# end::ch13-roc-curve[]

# ParamGridBuilder builds one ParamMap per combination of the listed values.
# With one Param and two values, that is two maps. elasticNetParam mixes the
# penalty types (0.0 = L2, 1.0 = L1).
# NOTE(review): `lr` never sets regParam, so it stays at its 0.0 default. The
# elastic-net mix only has an effect when regParam > 0, so this grid may
# compare two effectively unregularized models — consider adding
# .addGrid(lr.regParam, [...]).
# tag::ch13-grid-search[]
from pyspark.ml.tuning import ParamGridBuilder

grid_search = (
    ParamGridBuilder()  # <1>
    .addGrid(lr.elasticNetParam, [0.0, 1.0])  # <2>
    .build()  # <3>
)

print(grid_search)
# [
#     {Param(parent='LogisticRegression_14302c005814',
#            name='elasticNetParam',
#            doc='...'): 0.0},  <4>
#     {Param(parent='LogisticRegression_14302c005814',
#            name='elasticNetParam',
#            doc='...'): 1.0}  <4>
# ]


# end::ch13-grid-search[]


# 3-fold cross-validation over the grid, scoring each fold with the AUC
# evaluator. The whole pipeline is the estimator, so the imputer and scaler
# are refit inside every fold as well. avgMetrics holds one averaged score per
# ParamMap, in grid order. collectSubModels=True keeps every per-fold model,
# which costs extra memory.
# NOTE(review): pipeline_food_model (the best model) is not used later in this
# script; the coefficient inspection still reads lr_model.
# tag::ch13-cross-validation[]
from pyspark.ml.tuning import CrossValidator

cv = CrossValidator(
    estimator=food_pipeline,
    estimatorParamMaps=grid_search,
    evaluator=evaluator,
    numFolds=3,
    seed=13,
    collectSubModels=True,
)

cv_model = cv.fit(train)  # <1>

print(cv_model.avgMetrics)
# [0.9899971586317382, 0.9899992947698821]  <2>

pipeline_food_model = cv_model.bestModel
# end::ch13-cross-validation[]

# fmt:off
# tag::ch13-feature-coefficients[]
import pandas as pd

# The assembler's metadata stores one attribute per vector slot, grouped by
# attribute type ("numeric", and possibly "binary"/"nominal"). Flatten every
# group and order by "idx" so that the names line up with coefficient
# positions even if a group other than "numeric" is present.
feature_attrs = (
    food_features
    .schema["features"]
    .metadata["ml_attr"]["attrs"]
)
feature_names = ["(Intercept)"] + [  # <1>
    attr["name"]
    for attr in sorted(
        (a for group in feature_attrs.values() for a in group),
        key=lambda a: a["idx"],
    )
]

# `coefficients` is typed as a generic Vector and may be sparse (plausible
# under L1 regularization). `.values` on a SparseVector holds only the
# non-zero entries, which would misalign with feature_names. toArray() yields
# one value per feature for dense and sparse vectors alike.
# NOTE(review): lr_model is the stage fitted without cross-validation; use
# pipeline_food_model.stages[-1] to inspect the tuned model instead.
feature_coefficients = [lr_model.intercept] + list(  # <1>
    lr_model.coefficients.toArray()
)


coefficients = pd.DataFrame(
    feature_coefficients, index=feature_names, columns=["coef"]
)

coefficients["abs_coef"] = coefficients["coef"].abs()  # <2>

print(coefficients.sort_values(["abs_coef"]))
#                                coef   abs_coef
# kirsch                     0.004305   0.004305
# jam_or_jelly              -0.006601   0.006601
# lemon                     -0.010902   0.010902
# food_processor            -0.018454   0.018454
# phyllo_puff_pastry_dough  -0.020231   0.020231
# ...                             ...        ...
# cauliflower              -13.928099  13.928099
# rye                      -13.987067  13.987067
# plantain                 -15.551487  15.551487
# quick_and_healthy        -15.908631  15.908631
# horseradish              -17.172171  17.172171

# [514 rows x 2 columns]
# end::ch13-feature-coefficients[]
