#!/usr/bin/env python3

from functools import reduce

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

# tag::exo7_1[]
# Count the rows of `elements` whose Radioactive column is non-null, via SQL.
_exo7_1_query = """
select count(*) as count from elements where Radioactive is not null
"""
spark.sql(_exo7_1_query).show()

# +-----+
# |count|
# +-----+
# |   37|
# +-----+

# end::exo7_1[]

# tag::sol7_1[]
# DataFrame-API version of the count: keep the rows where Radioactive is
# non-null, then count them with a key-less (global) grouping.
(
    elements.filter(F.col("Radioactive").isNotNull())
    .groupBy()
    .count()
    .show()
)

# +-----+
# |count|
# +-----+
# |   37|
# +-----+
# end::sol7_1[]

# tag::sol7_2[]

# Per (model, capacity_GB) pair: sum the `failure` column and count the rows.
_totals_7_2 = full_data.groupBy("model", "capacity_GB").agg(
    F.sum("failure").alias("failures"),
    F.count("*").alias("drive_days"),
)

# The failure rate is the summed failures divided by the number of rows.
sol7_2 = _totals_7_2.selectExpr(
    "model", "capacity_GB", "failures / drive_days failure_rate"
)

sol7_2.show(n=10)
# +--------------------+--------------------+--------------------+
# |               model|         capacity_GB|        failure_rate|
# +--------------------+--------------------+--------------------+
# |       ST12000NM0117|             11176.0|0.006934812760055479|
# |      WDC WD5000LPCX|   465.7617416381836|1.013736124486796...|
# |         ST6000DX000|-9.31322574615478...|                 0.0|
# |         ST6000DM004|    5589.02986907959|                 0.0|
# |      WDC WD2500AAJS|  232.88591766357422|                 0.0|
# |         ST4000DM005|   3726.023277282715|                 0.0|
# |HGST HMS5C4040BLE641|   3726.023277282715|                 0.0|
# |       ST500LM012 HN|   465.7617416381836|2.290804285402249...|
# |       ST12000NM0008|             11176.0|3.112598241381993...|
# |HGST HUH721010ALE600|-9.31322574615478...|                 0.0|
# +--------------------+--------------------+--------------------+
# only showing top 10 rows

# end::sol7_2[]

# tag::sol7_3[]

# Column names present in every DataFrame of `data`.
common_columns = list(
    reduce(set.intersection, [set(frame.columns) for frame in data])
)

# Restrict each frame to the shared columns, then stack them row-wise.
_stacked_7_3 = reduce(
    lambda acc, frame: acc.union(frame),
    [frame.select(common_columns) for frame in data],
)

# Days between the first and the last record seen for a group.
_record_span_7_3 = F.datediff(
    F.max("date").cast("date"), F.min("date").cast("date")
)

# One row per (serial_number, model, capacity_GB) with its span as `age`.
full_data = (
    _stacked_7_3.selectExpr(
        "serial_number",
        "model",
        "capacity_bytes / pow(1024, 3) capacity_GB",
        "date",
        "failure",
    )
    .groupby("serial_number", "model", "capacity_GB")
    .agg(_record_span_7_3.alias("age"))
)

# Average that age over every drive sharing a model and capacity.
sol7_3 = full_data.groupby("model", "capacity_GB").agg(
    F.avg("age").alias("avg_age")
)

sol7_3.orderBy(F.desc("avg_age")).show(10)
# +--------------------+-----------------+------------------+
# |               model|      capacity_GB|           avg_age|
# +--------------------+-----------------+------------------+
# |      ST1000LM024 HN|931.5133895874023|             364.0|
# |HGST HMS5C4040BLE641|3726.023277282715|             364.0|
# |         ST8000DM002|7452.036460876465| 361.1777375201288|
# |Seagate BarraCuda...|465.7617416381836| 360.8888888888889|
# |       ST10000NM0086|           9314.0| 357.7377450980392|
# |        ST8000NM0055|7452.036460876465|  357.033857892227|
# |      WDC WD5000BPKT|465.7617416381836| 355.3636363636364|
# |HGST HUS726040ALE610|3726.023277282715| 354.0689655172414|
# |      WDC WD5000LPCX|465.7617416381836|352.42857142857144|
# |HGST HUH728080ALE600|7452.036460876465| 349.7186311787072|
# +--------------------+-----------------+------------------+
# only showing top 10 rows

# end::sol7_3[]

# tag::sol7_4[]

# Column names present in every DataFrame of `data`.
common_columns = list(
    reduce(set.intersection, [set(frame.columns) for frame in data])
)

# Restrict each frame to the shared columns, then stack them row-wise.
_stacked_7_4 = reduce(
    lambda acc, frame: acc.union(frame),
    [frame.select(common_columns) for frame in data],
)

# Keep only the records dated on the first day of a month, with the capacity
# expressed in TB (bytes / 1024^4).
_month_starts_7_4 = _stacked_7_4.selectExpr(
    "cast(date as date) as date",
    "capacity_bytes / pow(1024, 4) as capacity_TB",
).where("extract(day from date) = 1")

# Total capacity recorded on each of those dates.
sol7_4 = _month_starts_7_4.groupby("date").sum("capacity_TB")

sol7_4.orderBy("date").show(10)
# +----------+-----------------+
# |      date| sum(capacity_TB)|
# +----------+-----------------+
# |2019-01-01|732044.6322980449|
# |2019-02-01|745229.8319376707|
# |2019-03-01|760761.8200763315|
# |2019-04-01|784048.2895324379|
# |2019-05-01| 781405.457732901|
# |2019-06-01|834218.0686636567|
# |2019-07-01|833865.5910149883|
# |2019-08-01|846133.1006234661|
# |2019-09-01|858464.0372464955|
# |2019-10-01|884306.1266535893|
# +----------+-----------------+
# only showing top 10 rows

# end::sol7_4[]

# tag::sol7_5[]

# Column names present in every DataFrame of `data`. Sorted so the column
# order of the unioned frame is reproducible: the iteration order of a set of
# strings depends on Python's per-process hash seed.
common_columns = sorted(
    reduce(
        lambda x, y: x.intersection(y), [set(df.columns) for df in data]
    )
)

# Stack every frame row-wise on the shared columns.
data7_5 = reduce(
    lambda x, y: x.select(common_columns).union(y.select(common_columns)),
    data,
)

# Number of records reporting each capacity, per model.
capacity_count = data7_5.groupby("model", "capacity_bytes").agg(
    F.count("*").alias("capacity_occurence")
)

# Highest record count reached by any single capacity, per model.
most_common_capacity = capacity_count.groupby("model").agg(
    F.max("capacity_occurence").alias("most_common_capacity_occurence")
)

# Keep, for each model, the capacity whose record count equals that maximum.
# Two capacities can tie for the maximum; left alone, that yields several rows
# for one model and multiplies that model's records when `sol7_5` is joined
# back on "model". Ties are collapsed to the largest capacity so that every
# model appears exactly once.
sol7_5 = (
    most_common_capacity.join(
        capacity_count,
        (capacity_count["model"] == most_common_capacity["model"])
        & (
            capacity_count["capacity_occurence"]
            == most_common_capacity["most_common_capacity_occurence"]
        ),
    )
    .select(most_common_capacity["model"], "capacity_bytes")
    .groupby("model")
    .agg(F.max("capacity_bytes").alias("capacity_bytes"))
)

sol7_5.show(5)
# +--------------------+--------------+
# |               model|capacity_bytes|
# +--------------------+--------------+
# |      WDC WD5000LPVX|  500107862016|
# |       ST12000NM0117|12000138625024|
# | TOSHIBA MD04ABA500V| 5000981078016|
# |HGST HUS726040ALE610| 4000787030016|
# |HGST HUH721212ALE600|12000138625024|
# +--------------------+--------------+
# only showing top 5 rows

# Swap each record's own capacity_bytes for the per-model value in `sol7_5`.
_records_without_capacity = data7_5.drop("capacity_bytes")
full_data = _records_without_capacity.join(sol7_5, on="model")

# end::sol7_5[]
