#!/usr/bin/env python3

# pylint: disable=wrong-import-position
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name
# pylint: disable=protected-access

from pyspark.sql import SparkSession

# Reuse the active SparkSession if one exists, otherwise build one with the
# default configuration. Every DataFrame in this script is created from it.
spark = SparkSession.builder.getOrCreate()

# tag::ch14-transformer-shell[]
from pyspark.ml import Transformer


class ScalarNAFiller(Transformer):  # <1>
    """Empty shell: a custom transformer starts by subclassing Transformer."""


# end::ch14-transformer-shell[]

# tag::ch14-transformer-test-function[]
import pyspark.sql.functions as F
from pyspark.sql import Column, DataFrame

# Small 4-row test frame. Column `three` is null in the last two rows, which
# is what the NA-filling examples below operate on.
test_df = spark.createDataFrame(
    [[1, 2, 4, 1], [3, 6, 5, 4], [9, 4, None, 9], [11, 17, None, 3]],
    ["one", "two", "three", "four"],
)


def scalarNAFillerFunction(
    df: DataFrame, inputCol: Column, outputCol: str, filler: float = 0.0
):
    """Copy `inputCol` into `outputCol`, then replace nulls in `outputCol`.

    Args:
        df: the data frame to transform.
        inputCol: column expression copied into the output column.
        outputCol: name of the column receiving the copy.
        filler: value substituted for the nulls of `outputCol`.

    Returns:
        A new data frame containing `outputCol` with its nulls filled.
    """
    copied = df.withColumn(outputCol, inputCol)
    return copied.fillna(filler, subset=outputCol)


# Copy `three` into a new column `five`, filling its nulls with -99.0.
scalarNAFillerFunction(test_df, F.col("three"), "five", -99.0).show()
# +---+---+-----+----+----+
# |one|two|three|four|five|
# +---+---+-----+----+----+
# |  1|  2|    4|   1|   4|
# |  3|  6|    5|   4|   5|
# |  9|  4| null|   9| -99| <1>
# | 11| 17| null|   3| -99| <1>
# +---+---+-----+----+----+
# end::ch14-transformer-test-function[]


# tag::ch14-transformer-test-function-nested[]
def scalarNAFillerTransformFunction(
    inputCol: Column, outputCol: str, filler: float = 0.0
):
    """Return a one-argument function usable with `DataFrame.transform`.

    The returned callable takes a data frame, copies `inputCol` into
    `outputCol` and fills the nulls of `outputCol` with `filler`.
    """

    def _inner(df):
        with_copy = df.withColumn(outputCol, inputCol)
        return with_copy.fillna(filler, subset=outputCol)

    return _inner


# end::ch14-transformer-test-function-nested[]


# Two equivalent ways to apply the curried function: call the returned
# function on the frame directly, or hand it to `DataFrame.transform`.
# Both results are discarded; they only illustrate the calling convention.
scalarNAFillerTransformFunction(F.col("three"), "five")(test_df)

test_df.transform(scalarNAFillerTransformFunction(F.col("three"), "five"))


# tag::ch14-transformer-filler-param[]
from pyspark.ml.param import Param, Params, TypeConverters

# A standalone Param, built outside any class to show its anatomy. Its
# placeholder parent appears as 'undefined' in the repr shown below.
filler = Param(
    Params._dummy(),  # <1>
    "filler",  # <2>
    "Value we want to replace our null values with.",  # <3>
    typeConverter=TypeConverters.toFloat,  # <4>
)

# Bare expression: only displays the Param's repr when run interactively.
filler
# Param(parent='undefined', name='filler',
#       doc='Value we want to replace our null values with.')
# end::ch14-transformer-filler-param[]


# tag::ch14-transformer-shell-params[]
from pyspark.ml.param.shared import HasInputCol, HasOutputCol


class ScalarNAFiller(Transformer, HasInputCol, HasOutputCol):  # <1>
    """Transformer shell with inputCol/outputCol mixins and a `filler` Param."""

    filler = Param(  # <2>
        Params._dummy(),
        "filler",
        "Value we want to replace our null values with.",
        typeConverter=TypeConverters.toFloat,
    )


# end::ch14-transformer-shell-params[]

# tag::ch14-transformer-setParams[]
from pyspark import keyword_only  # <1>


@keyword_only
def setParams(self, *, inputCol=None, outputCol=None, filler=None):  # <2>
    """Forward the keyword arguments captured by `keyword_only` to `_set`."""
    return self._set(**self._input_kwargs)  # <3>


# end::ch14-transformer-setParams[]


# tag::ch14-transform-setters[]
def setFiller(self, new_filler):
    """Set the `filler` Param by delegating to `setParams`."""
    updated = self.setParams(filler=new_filler)  # <1>
    return updated


def setInputCol(self, new_inputCol):
    """Set the `inputCol` Param by delegating to `setParams`."""
    updated = self.setParams(inputCol=new_inputCol)  # <1>
    return updated


def setOutputCol(self, new_outputCol):
    """Set the `outputCol` Param by delegating to `setParams`."""
    updated = self.setParams(outputCol=new_outputCol)  # <1>
    return updated


# end::ch14-transform-setters[]

# tag::ch14-transformer-getter[]
def getFiller(self):
    """Return the `filler` value, falling back to its default when unset."""
    filler_param = self.filler
    return self.getOrDefault(filler_param)


# end::ch14-transformer-getter[]


# tag::ch14-transformer-shell-getset[]
class ScalarNAFiller(Transformer, HasInputCol, HasOutputCol):
    """ScalarNAFiller with explicit setters and a getter (no __init__ yet).

    Every setter goes through `setParams` and returns its result, so calls
    can be chained.
    """

    # A real Param is required here: `setParams(filler=...)` and
    # `getFiller()` look the `filler` attribute up as a Param, so a
    # placeholder value (e.g. `[...]`) makes both fail at runtime.
    filler = Param(
        Params._dummy(),
        "filler",
        "Value we want to replace our null values with.",
        typeConverter=TypeConverters.toFloat,
    )

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, filler=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setFiller(self, new_filler):
        return self.setParams(filler=new_filler)

    def getFiller(self):
        return self.getOrDefault(self.filler)

    def setInputCol(self, new_inputCol):
        return self.setParams(inputCol=new_inputCol)

    def setOutputCol(self, new_outputCol):
        return self.setParams(outputCol=new_outputCol)


# end::ch14-transformer-shell-getset[]


# tag::ch14-transformer-init[]
class ScalarNAFiller(Transformer, HasInputCol, HasOutputCol):
    """ScalarNAFiller showing the keyword-only constructor / setParams pair."""

    # `__init__` registers a default for `filler` by name, so the Param has
    # to exist on the class; without it instantiation fails.
    filler = Param(
        Params._dummy(),
        "filler",
        "Value we want to replace our null values with.",
        typeConverter=TypeConverters.toFloat,
    )

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, filler=None):
        super().__init__()  # <1>
        self._setDefault(filler=None)  # <2>
        kwargs = self._input_kwargs
        self.setParams(**kwargs)  # <3>

    @keyword_only
    def setParams(self, *, inputCol=None, outputCol=None, filler=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    # Rest of the methods


# end::ch14-transformer-init[]

# tag::ch14-transformer-transform[]


def _transform(self, dataset):
    """Copy `inputCol` (cast to double) into `outputCol` and fill its nulls.

    Raises:
        ValueError: if `inputCol` has not been set.
    """
    if not self.isSet("inputCol"):
        raise ValueError(  # <1>
            "No input column set for the ScalarNAFiller transformer."
        )

    source = dataset[self.getInputCol()]
    target_name = self.getOutputCol()
    replacement = self.getFiller()

    as_double = dataset.withColumn(target_name, source.cast("double"))
    return as_double.fillna(replacement, target_name)


# end::ch14-transformer-transform[]


# tag::ch14-transformer-complete[]


class ScalarNAFiller(Transformer, HasInputCol, HasOutputCol):
    """Fill the nulls of `inputCol` with the scalar `filler` into `outputCol`.

    `_transform` casts the input column to double before filling, so the
    output column is a double column. `inputCol` must be set.
    NOTE(review): `filler` defaults to None; confirm callers always set it,
    since `fillna` is then called with None.
    """

    filler = Param(  # <1>
        Params._dummy(),
        "filler",
        "Value we want to replace our null values with.",
        typeConverter=TypeConverters.toFloat,
    )

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, filler=None):  # <2>
        super().__init__()
        self._setDefault(filler=None)
        self.setParams(**self._input_kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, filler=None):  # <3>
        return self._set(**self._input_kwargs)

    def setFiller(self, new_filler):  # <4>
        return self.setParams(filler=new_filler)

    def setInputCol(self, new_inputCol):  # <4>
        return self.setParams(inputCol=new_inputCol)

    def setOutputCol(self, new_outputCol):  # <4>
        return self.setParams(outputCol=new_outputCol)

    def getFiller(self):  # <5>
        filler_param = self.filler
        return self.getOrDefault(filler_param)

    def _transform(self, dataset):  # <6>
        if not self.isSet("inputCol"):
            raise ValueError(
                "No input column set for the "
                "ScalarNAFiller transformer."
            )
        source_column = dataset[self.getInputCol()]
        target_name = self.getOutputCol()
        replacement = self.getFiller()

        as_double = dataset.withColumn(
            target_name, source_column.cast("double")
        )
        return as_double.fillna(replacement, target_name)


# end::ch14-transformer-complete[]

# tag::ch14-transformer-test[]
# Instantiate with keyword arguments only (the constructor is @keyword_only).
# The int -99 goes through the Param's TypeConverters.toFloat; the output
# column is a double because _transform casts the input column.
test_ScalarNAFiller = ScalarNAFiller(
    inputCol="three", outputCol="five", filler=-99
)

test_ScalarNAFiller.transform(test_df).show()
# +---+---+-----+----+-----+
# |one|two|three|four| five|
# +---+---+-----+----+-----+
# |  1|  2|    4|   1|  4.0|
# |  3|  6|    5|   4|  5.0|
# |  9|  4| null|   9|-99.0|
# | 11| 17| null|   3|-99.0|
# +---+---+-----+----+-----+
# end::ch14-transformer-test[]

# tag::ch14-transformer-test2[]

# setFiller() changes the transformer's own `filler` (17 from here on) and
# returns the transformer, so transform() can be chained directly.
test_ScalarNAFiller.setFiller(17).transform(test_df).show()  # <1>

# A param map passed to transform() overrides the embedded Params for this
# call only. Since setFiller(17) already ran, both calls print the table
# shown below.
test_ScalarNAFiller.transform(
    test_df, params={test_ScalarNAFiller.filler: 17}  # <2>
).show()
# +---+---+-----+----+----+
# |one|two|three|four|five|
# +---+---+-----+----+----+
# |  1|  2|    4|   1| 4.0|
# |  3|  6|    5|   4| 5.0|
# |  9|  4| null|   9|17.0|
# | 11| 17| null|   3|17.0|
# +---+---+-----+----+----+


# end::ch14-transformer-test2[]


# tag::ch14-estimator-test-function[]


def test_ExtremeValueCapperModel_transform(
    df: DataFrame,
    inputCol: Column,
    outputCol: str,
    cap: float,
    floor: float,
):
    """Prototype of the model's transform: clamp `inputCol` to [floor, cap].

    Values above `cap` become `cap`, values below `floor` become `floor`,
    anything else is copied as-is into `outputCol`.
    """
    clamped = (
        F.when(inputCol > cap, cap)  # <1>
        .when(inputCol < floor, floor)  # <1>
        .otherwise(inputCol)  # <1>
    )
    return df.withColumn(outputCol, clamped)


def test_ExtremeValueCapper_fit(
    df: DataFrame, inputCol: Column, outputCol: str, boundary: float
):
    """Prototype of the estimator's fit: derive bounds, then apply them.

    cap = mean + boundary * stddev and floor = mean - boundary * stddev,
    both computed over `inputCol` of `df`.

    Raises:
        ValueError: if the mean or standard deviation comes back null
            (e.g. no non-null values), so no bounds can be computed.
    """
    avg, stddev = df.agg(
        F.mean(inputCol), F.stddev(inputCol)
    ).head()  # <2>
    # Null aggregates would otherwise surface as an opaque TypeError below.
    if avg is None or stddev is None:
        raise ValueError(
            "Cannot compute cap/floor: the mean or standard deviation of "
            "the input column is null (not enough non-null values)."
        )
    cap = avg + boundary * stddev  # <3>
    floor = avg - boundary * stddev  # <3>
    return test_ExtremeValueCapperModel_transform(  # <4>
        df, inputCol, outputCol, cap, floor  # <4>
    )


# end::ch14-estimator-test-function[]


# tag::ch14-estimator-mixin[]


class _ExtremeValueCapperParams(HasInputCol, HasOutputCol):
    """Params mixin shared by ExtremeValueCapper and ExtremeValueCapperModel.

    Adds a `boundary` Param, defaulting to 0.0, on top of the
    inputCol/outputCol mixins.
    """

    boundary = Param(
        Params._dummy(),
        "boundary",
        "Multiple of standard deviation for the cap and floor. Default = 0.0.",
        TypeConverters.toFloat,
    )

    def __init__(self, *init_args):
        super().__init__(*init_args)  # <1>
        self._setDefault(boundary=0.0)  # <2>

    def getBoundary(self):  # <3>
        boundary_param = self.boundary
        return self.getOrDefault(boundary_param)


# end::ch14-estimator-mixin[]


# tag::ch14-estimator-model[]

from pyspark.ml import Model


class ExtremeValueCapperModel(Model, _ExtremeValueCapperParams):  # <1>
    """Fitted model clamping `inputCol` into [`floor`, `cap`] as `outputCol`."""

    cap = Param(
        Params._dummy(),
        "cap",
        "Upper bound of the values `inputCol` can take. "
        "Values will be capped to this value.",
        TypeConverters.toFloat,
    )
    floor = Param(
        Params._dummy(),
        "floor",
        "Lower bound of the values `inputCol` can take. "
        "Values will be floored to this value.",
        TypeConverters.toFloat,
    )

    @keyword_only
    def __init__(
        self, inputCol=None, outputCol=None, cap=None, floor=None
    ):
        super().__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(
        self, inputCol=None, outputCol=None, cap=None, floor=None
    ):
        # Needed by __init__ above: sets the Params passed explicitly.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        """Replace values above `cap` by `cap` and below `floor` by `floor`.

        Raises:
            ValueError: if `inputCol` has not been set.
        """
        if not self.isSet("inputCol"):
            raise ValueError(
                "No input column set for the "
                "ExtremeValueCapperModel transformer."
            )
        input_column = dataset[self.getInputCol()]
        output_column = self.getOutputCol()
        cap_value = self.getOrDefault("cap")
        floor_value = self.getOrDefault("floor")

        return dataset.withColumn(
            output_column,
            F.when(input_column > cap_value, cap_value)
            .when(input_column < floor_value, floor_value)
            .otherwise(input_column),
        )


# end::ch14-estimator-model[]

# Full class defined here
class ExtremeValueCapperModel(Model, _ExtremeValueCapperParams):  # <1>
    """Fitted model clamping `inputCol` into [`floor`, `cap`] as `outputCol`.

    Values above `cap` become `cap`, values below `floor` become `floor`,
    and the remaining values are copied unchanged.
    """

    cap = Param(
        Params._dummy(),
        "cap",
        "Upper bound of the values `inputCol` can take. "
        "Values will be capped to this value.",
        TypeConverters.toFloat,
    )
    floor = Param(
        Params._dummy(),
        "floor",
        "Lower bound of the values `inputCol` can take. "
        "Values will be floored to this value.",
        TypeConverters.toFloat,
    )

    @keyword_only
    def __init__(
        self, inputCol=None, outputCol=None, cap=None, floor=None
    ):
        super().__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(
        self, inputCol=None, outputCol=None, cap=None, floor=None
    ):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setCap(self, new_cap):
        return self.setParams(cap=new_cap)

    def setFloor(self, new_floor):
        return self.setParams(floor=new_floor)

    def setInputCol(self, new_inputCol):
        return self.setParams(inputCol=new_inputCol)

    def setOutputCol(self, new_outputCol):
        return self.setParams(outputCol=new_outputCol)

    def getCap(self):
        return self.getOrDefault(self.cap)

    def getFloor(self):
        return self.getOrDefault(self.floor)

    def _transform(self, dataset):
        """Clamp `inputCol` into [floor, cap] and write it to `outputCol`.

        Raises:
            ValueError: if `inputCol` has not been set.
        """
        if not self.isSet("inputCol"):
            raise ValueError(
                "No input column set for the "
                "ExtremeValueCapperModel transformer."
            )
        input_column = dataset[self.getInputCol()]
        output_column = self.getOutputCol()
        cap_value = self.getCap()
        floor_value = self.getFloor()

        return dataset.withColumn(
            output_column,
            F.when(input_column > cap_value, cap_value)
            .when(input_column < floor_value, floor_value)
            .otherwise(input_column),
        )


# tag::ch14-estimator-fit[]
from pyspark.ml import Estimator


class ExtremeValueCapper(Estimator, _ExtremeValueCapperParams):  # <1>
    """Estimator learning a cap/floor `boundary` std devs around the mean."""

    # [... __init__(), setters definition]

    def _fit(self, dataset):
        """Compute the bounds on `dataset` and return the fitted model.

        Raises:
            ValueError: if the mean or standard deviation of `inputCol` is
                null, so no bounds can be computed.
        """
        input_column = self.getInputCol()  # <2>
        output_column = self.getOutputCol()  # <2>
        boundary = self.getBoundary()  # <2>

        avg, stddev = dataset.agg(  # <3>
            F.mean(input_column), F.stddev(input_column)
        ).head()

        # Null aggregates would otherwise surface as an opaque TypeError.
        if avg is None or stddev is None:
            raise ValueError(
                "Cannot fit ExtremeValueCapper: the mean or standard "
                "deviation of the input column is null (not enough "
                "non-null values)."
            )

        cap_value = avg + boundary * stddev
        floor_value = avg - boundary * stddev
        return ExtremeValueCapperModel(  # <4>
            inputCol=input_column,  # <4>
            outputCol=output_column,  # <4>
            cap=cap_value,  # <4>
            floor=floor_value,  # <4>
        )


# end::ch14-estimator-fit[]


# Full estimator code


class ExtremeValueCapper(Estimator, _ExtremeValueCapperParams):
    """Estimator learning a cap/floor `boundary` std devs around the mean.

    `fit()` aggregates the mean and standard deviation of `inputCol` and
    returns an ExtremeValueCapperModel with
    cap = mean + boundary * stddev and floor = mean - boundary * stddev.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, boundary=None):
        super().__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, boundary=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setBoundary(self, new_boundary):
        # Return the result like the other setters so calls can be chained.
        return self.setParams(boundary=new_boundary)

    def setInputCol(self, new_inputCol):
        return self.setParams(inputCol=new_inputCol)

    def setOutputCol(self, new_outputCol):
        return self.setParams(outputCol=new_outputCol)

    def _fit(self, dataset):
        """Compute the bounds on `dataset` and return the fitted model.

        Raises:
            ValueError: if the mean or standard deviation of `inputCol` is
                null, so no bounds can be computed.
        """
        input_column = self.getInputCol()
        output_column = self.getOutputCol()
        boundary = self.getBoundary()

        avg, stddev = dataset.agg(
            F.mean(input_column), F.stddev(input_column)
        ).head()

        # Null aggregates would otherwise surface as an opaque TypeError.
        if avg is None or stddev is None:
            raise ValueError(
                "Cannot fit ExtremeValueCapper: the mean or standard "
                "deviation of the input column is null (not enough "
                "non-null values)."
            )

        cap_value = avg + boundary * stddev
        floor_value = avg - boundary * stddev
        return ExtremeValueCapperModel(
            inputCol=input_column,
            outputCol=output_column,
            cap=cap_value,
            floor=floor_value,
        )


# tag::ch14-estimator-test[]

# Fit on test_df, then apply the learned cap/floor to the same frame: with
# boundary=1.0, values of `one` outside mean +/- 1 stddev are clamped.
test_EVC = ExtremeValueCapper(
    inputCol="one", outputCol="five", boundary=1.0
)

test_EVC.fit(test_df).transform(test_df).show()  # <1>
# +---+---+-----+----+------------------+
# |one|two|three|four|              five|
# +---+---+-----+----+------------------+
# |  1|  2|    4|   1|1.2390477143047667| <2>
# |  3|  6|    5|   4|               3.0|
# |  9|  4| null|   9|               9.0|
# | 11| 17| null|   3|10.760952285695232| <2>
# +---+---+-----+----+------------------+

# end::ch14-estimator-test[]

# tag::ch14-transformer-many-columns[]
from pyspark.ml.param.shared import HasInputCols, HasOutputCols


class ScalarNAFiller(
    Transformer,
    HasInputCol,
    HasOutputCol,
    HasInputCols,  # <1>
    HasOutputCols,  # <1>
):
    """Shell accepting either a single column or a list of columns."""


# end::ch14-transformer-many-columns[]


# tag::ch14-transformer-check-params[]
def checkParams(self):
    """Validate the column Params of a multi-column ScalarNAFiller.

    Raises:
        ValueError: if both or neither of `inputCol` / `inputCols` are set,
            or if `inputCols` and `outputCols` have different lengths.
    """
    # Single-column and multi-column modes are mutually exclusive.
    if self.isSet("inputCol") and (self.isSet("inputCols")):  # <1>
        raise ValueError(
            "Only one of `inputCol` or `inputCols` must be set."
        )

    if not (self.isSet("inputCol") or self.isSet("inputCols")):  # <2>
        raise ValueError("One of `inputCol` or `inputCols` must be set.")

    if self.isSet("inputCols"):
        if len(self.getInputCols()) != len(self.getOutputCols()):  # <3>
            raise ValueError(
                "The length of `inputCols` does not match"
                " the length of `outputCols`"
            )


# end::ch14-transformer-check-params[]

# tag::ch14-transformer-mult-transform[]
def _transform(self, dataset):
    """Fill the nulls of one or several columns with the scalar `filler`.

    When the output columns differ from the input columns, the inputs are
    first copied into the outputs; the copies are then filled.

    Raises:
        ValueError: propagated from `checkParams` on inconsistent Params.
    """
    self.checkParams()  # <1>

    # The mode is decided by which *input* Param is set (as in
    # checkParams); the output Param of the same mode is used, so an unset
    # `outputCol` falls back to getOutputCol() instead of `outputCols`.
    single_column = self.isSet("inputCol")
    input_columns = (  # <2>
        [self.getInputCol()] if single_column else self.getInputCols()
    )
    output_columns = (  # <2>
        [self.getOutputCol()] if single_column else self.getOutputCols()
    )

    answer = dataset

    if input_columns != output_columns:  # <3>
        for in_col, out_col in zip(input_columns, output_columns):
            answer = answer.withColumn(out_col, F.col(in_col))

    na_filler = self.getFiller()
    # Fill `answer`, not `dataset`: the copied output columns only exist
    # on `answer`.
    return answer.fillna(na_filler, output_columns)


# end::ch14-transformer-mult-transform[]


class ScalarNAFiller(
    Transformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols
):
    """Fills the `null` values of inputCol with a scalar value `filler`.

    Works on a single column (`inputCol` / `outputCol`) or on several
    columns at once (`inputCols` / `outputCols`). When the output columns
    differ from the input columns, the inputs are copied first and the
    copies are filled; otherwise the columns are filled in place.
    NOTE(review): `filler` defaults to None; confirm callers always set it,
    since `fillna` is then called with None.
    """

    filler = Param(
        Params._dummy(),
        "filler",
        "Value we want to replace our null values with.",
        typeConverter=TypeConverters.toFloat,
    )

    @keyword_only
    def __init__(
        self,
        inputCol=None,
        outputCol=None,
        inputCols=None,
        outputCols=None,
        filler=None,
    ):
        super().__init__()
        self._setDefault(filler=None)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(
        self,
        inputCol=None,
        outputCol=None,
        inputCols=None,
        outputCols=None,
        filler=None,
    ):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setFiller(self, new_filler):
        return self.setParams(filler=new_filler)

    def setInputCol(self, new_inputCol):
        return self.setParams(inputCol=new_inputCol)

    def setOutputCol(self, new_outputCol):
        return self.setParams(outputCol=new_outputCol)

    def setInputCols(self, new_inputCols):
        return self.setParams(inputCols=new_inputCols)

    def setOutputCols(self, new_outputCols):
        return self.setParams(outputCols=new_outputCols)

    def getFiller(self):
        return self.getOrDefault(self.filler)

    def checkParams(self):
        """Validate the column Params before transforming.

        Raises:
            ValueError: if both or neither of `inputCol` / `inputCols` are
                set, or if `inputCols` and `outputCols` differ in length.
        """
        # Test #1: either inputCol or inputCols can be set (but not both).
        if self.isSet("inputCol") and (self.isSet("inputCols")):
            raise ValueError(
                "Only one of `inputCol` or `inputCols` must be set."
            )

        # Test #2: at least one of inputCol or inputCols must be set.
        if not (self.isSet("inputCol") or self.isSet("inputCols")):
            raise ValueError(
                "One of `inputCol` or `inputCols` must be set."
            )

        # Test #3: if `inputCols` is set, then `outputCols`
        # must be a list of the same len()
        if self.isSet("inputCols"):
            if len(self.getInputCols()) != len(self.getOutputCols()):
                raise ValueError(
                    "The length of `inputCols` does not match"
                    " the length of `outputCols`"
                )

    def _transform(self, dataset):
        self.checkParams()

        # The mode is decided by which *input* Param is set (as in
        # checkParams). Single-column mode wraps inputCol / outputCol into
        # one-item lists; an unset outputCol falls back to getOutputCol().
        single_column = self.isSet("inputCol")
        input_columns = (
            [self.getInputCol()] if single_column else self.getInputCols()
        )
        output_columns = (
            [self.getOutputCol()]
            if single_column
            else self.getOutputCols()
        )

        answer = dataset

        # If input_columns == output_columns, we overwrite and no need to
        # create new columns.
        if input_columns != output_columns:
            for in_col, out_col in zip(input_columns, output_columns):
                answer = answer.withColumn(out_col, F.col(in_col))

        na_filler = self.getFiller()
        # Fill `answer`: the copied output columns only exist there.
        return answer.fillna(na_filler, output_columns)


# tag::ch14-custom-tr-es[]
# NOTE(review): BINARY_COLUMNS is not defined anywhere in this file; it is
# presumably the list of binary feature columns from the earlier data
# preparation -- define or import it before this point.
scalar_na_filler = ScalarNAFiller(
    inputCols=BINARY_COLUMNS, outputCols=BINARY_COLUMNS, filler=0.0  # <1>
)

# One capper per continuous column, each overwriting its own column with
# values clamped to mean +/- 2 standard deviations.
(
    extreme_value_capper_cal,
    extreme_value_capper_pro,
    extreme_value_capper_fat,
    extreme_value_capper_sod,
) = (
    ExtremeValueCapper(inputCol=column, outputCol=column, boundary=2.0)
    for column in ("calories", "protein", "fat", "sodium")
)

# end::ch14-custom-tr-es[]

# tag::ch14-new-pipeline[]
from pyspark.ml.pipeline import Pipeline

# The custom stages run first, so the later stages receive data with the
# binary columns' nulls filled and the four continuous columns capped.
# NOTE(review): imputer, continuous_assembler, continuous_scaler,
# preml_assembler and lr are not defined in this file; presumably they come
# from the previous chapter's pipeline -- confirm they are in scope here.
food_pipeline = Pipeline(
    stages=[
        scalar_na_filler,          # <1>
        extreme_value_capper_cal,  # <1>
        extreme_value_capper_pro,  # <1>
        extreme_value_capper_fat,  # <1>
        extreme_value_capper_sod,  # <1>
        imputer,
        continuous_assembler,
        continuous_scaler,
        preml_assembler,
        lr,
    ]
)
# end::ch14-new-pipeline[]

# tag::ch14-new-pipeline-application[]
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# 70/30 train/test split with a fixed seed (13) for reproducibility.
# NOTE(review): `food` is not defined in this file -- presumably the food
# data frame loaded earlier; confirm it is in scope.
train, test = food.randomSplit([0.7, 0.3], 13)

food_pipeline_model = food_pipeline.fit(train)

results = food_pipeline_model.transform(test)


evaluator = BinaryClassificationEvaluator(
    labelCol="dessert",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC",
)

# Despite its name, `accuracy` holds the area under the ROC curve
# (metricName="areaUnderROC" above), not a classification accuracy.
accuracy = evaluator.evaluate(results)
print(f"Area under ROC = {accuracy} ")
# Area under ROC = 0.9929619675735302

# end::ch14-new-pipeline-application[]

# tag::ch14-new-pipeline-save[]

# Expected to fail: the custom ScalarNAFiller stage is not MLWritable, so
# the pipeline writer refuses to save (error reproduced below).
food_pipeline_model.save("code/food_pipeline.model")
# ValueError: ('Pipeline write will fail on this pipeline because
# stage %s of type %s is not MLWritable',
# 'ScalarNAFiller_7fe16120b179', <class '__main__.ScalarNAFiller'>)

# end::ch14-new-pipeline-save[]

# tag::ch14-new-pipeline-read[]
from pyspark.ml.pipeline import PipelineModel
from .custom_feature import ScalarNAFiller, ExtremeValueCapperModel

# Save, then read the fitted pipeline back. Given the failure shown for the
# earlier save, this relies on the custom stage classes imported above
# supporting ML persistence.
# NOTE(review): presumably the `custom_feature` module versions add
# read/write support -- confirm. Saving to an existing path without
# `.write().overwrite()` also fails.
food_pipeline_model.save("code/food_pipeline.model")
food_pipeline_model = PipelineModel.read().load("code/food_pipeline.model")
# end::ch14-new-pipeline-read[]
