#!/usr/bin/env python3
"""This is the code for chapter 5 of the book PySpark in Action."""

# tag::relevant-imports[]

import os

from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException
import pyspark.sql.functions as F

# Start a Spark session, or reuse the one that is already running.
spark = SparkSession.builder.getOrCreate()

DIRECTORY = "./data/broadcast_logs"

# Read the pipe-delimited log extract (header row, types inferred by Spark),
# then add `duration_seconds`: characters 1-2, 4-5 and 7-8 of `Duration` are
# cast to integers and weighted as hours, minutes and seconds respectively.
logs = (
    spark.read.options(sep="|", header=True, inferSchema=True)
    .csv(os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"))
    .withColumn(
        "duration_seconds",
        F.col("Duration").substr(1, 2).cast("int") * 60 * 60
        + F.col("Duration").substr(4, 2).cast("int") * 60
        + F.col("Duration").substr(7, 2).cast("int"),
    )
)

# end::relevant-imports[]

# tag::ch05-reading-logidentifier[]

DIRECTORY = "./data/broadcast_logs"

# Load the LogIdentifier reference table (pipe-delimited, header row, types
# inferred by Spark) and display the resulting schema.
log_identifier = spark.read.options(
    sep="|", header=True, inferSchema=True
).csv(os.path.join(DIRECTORY, "ReferenceTables/LogIdentifier.csv"))

log_identifier.printSchema()
# root
#  |-- LogIdentifierID: string (nullable = true) # <1>
#  |-- LogServiceID: integer (nullable = true) # <2>
#  |-- PrimaryFG: integer (nullable = true) # <3>

# Keep only the rows whose PrimaryFG flag equals 1, then report how many
# rows are left.
log_identifier = log_identifier.filter(F.col("PrimaryFG") == 1)
print(log_identifier.count())
# 758

log_identifier.show(n=5)
# +---------------+------------+---------+
# |LogIdentifierID|LogServiceID|PrimaryFG|
# +---------------+------------+---------+
# |           13ST|        3157|        1|
# |         2000SM|        3466|        1|
# |           70SM|        3883|        1|
# |           80SM|        3590|        1|
# |           90SM|        3470|        1|
# +---------------+------------+---------+
# only showing top 5 rows


# end::ch05-reading-logidentifier[]


# tag::ch05-verbose-join[]

# Join on an explicit equality predicate: each side contributes its own
# `LogServiceID` column to the result.
logs_and_channels_verbose = logs.join(
    log_identifier,
    on=logs["LogServiceID"] == log_identifier["LogServiceID"],
    how="inner",
)

logs_and_channels_verbose.printSchema()

# root
#  |-- LogServiceID: integer (nullable = true) <1>
#  |-- LogDate: timestamp (nullable = true)
#  |-- AudienceTargetAgeID: integer (nullable = true)
#  |-- AudienceTargetEthnicID: integer (nullable = true)
#  [...]
#  |-- duration_seconds: integer (nullable = true)
#  |-- LogIdentifierID: string (nullable = true)
#  |-- LogServiceID: integer (nullable = true) <2>
#  |-- PrimaryFG: integer (nullable = true)

# Selecting the duplicated name by plain string cannot be resolved; print the
# error Spark raises instead of letting it stop the script.
try:
    logs_and_channels_verbose.select("LogServiceID")
except AnalysisException as ambiguous_reference:
    print(ambiguous_reference)

# "Reference 'LogServiceID' is ambiguous, could be: LogServiceID, LogServiceID.;" <3>

# end::ch05-verbose-join[]

# tag::ch05-simplified-join[]

# Passing the column name as the join key leaves a single `LogServiceID`
# column in the result.
logs_and_channels = logs.join(log_identifier, on="LogServiceID", how="inner")

logs_and_channels.printSchema()

# root
#  |-- LogServiceID: integer (nullable = true)
#  |-- LogDate: timestamp (nullable = true)
#  |-- AudienceTargetAgeID: integer (nullable = true)
#  |-- AudienceTargetEthnicID: integer (nullable = true)
#  |-- CategoryID: integer (nullable = true)
#  [...]
#  |-- Language2: integer (nullable = true)
#  |-- duration_seconds: integer (nullable = true)
#  |-- LogIdentifierID: string (nullable = true)  <1>
#  |-- PrimaryFG: integer (nullable = true)       <1>

# end::ch05-simplified-join[]

# tag::ch05-join-full-name[]

# Rebuild the explicit-predicate join: the result carries two `LogServiceID`
# columns, one from each input data frame.
logs_and_channels_verbose = logs.join(
    log_identifier, logs["LogServiceID"] == log_identifier["LogServiceID"]
)

# Drop the right-hand copy by referring to it through its source data frame,
# so that the remaining `LogServiceID` can be selected by name. This must run
# against the verbose join built just above; the name-keyed
# `logs_and_channels` join has only one such column and nothing to
# disambiguate.
logs_and_channels_verbose.drop(log_identifier["LogServiceID"]).select(
    "LogServiceID")  # <1>

# DataFrame[LogServiceID: int]

# end::ch05-join-full-name[]
# tag::ch05-join-alias[]
# Alias both sides of the join so that columns can later be addressed with an
# alias-qualified name such as `right.LogServiceID`.
logs_and_channels_verbose = logs.alias("left").join(  # <1>
    log_identifier.alias("right"),  # <2>
    on=logs["LogServiceID"] == log_identifier["LogServiceID"],
    how="inner",
)

# Remove the right-hand copy through its alias-qualified name, then select
# the remaining `LogServiceID` column.
logs_and_channels_verbose.drop(F.col("right.LogServiceID")).select(
    "LogServiceID"
)  # <3>

# DataFrame[LogServiceID: int]


# end::ch05-join-alias[]

# tag::ch05-data-building[]

DIRECTORY = "./data/broadcast_logs"

# Category reference table: keep the ID, the code, and the English
# description under a table-specific column name.
cd_category = (
    spark.read.options(sep="|", header=True, inferSchema=True)
    .csv(os.path.join(DIRECTORY, "ReferenceTables/CD_Category.csv"))
    .select(
        "CategoryID",
        "CategoryCD",
        F.col("EnglishDescription").alias("Category_Description"),  # <1>
    )
)

# Program class reference table, prepared the same way.
cd_program_class = (
    spark.read.options(sep="|", header=True, inferSchema=True)
    .csv(os.path.join(DIRECTORY, "ReferenceTables/CD_ProgramClass.csv"))
    .select(
        "ProgramClassID",
        "ProgramClassCD",
        F.col("EnglishDescription").alias("ProgramClass_Description"),  # <2>
    )
)

# Attach both reference tables with left joins, so log records without a
# matching reference row are kept.
full_log = (
    logs_and_channels
    .join(cd_category, on="CategoryID", how="left")
    .join(cd_program_class, on="ProgramClassID", how="left")
)

# end::ch05-data-building[]

# tag::ch05-simple-counting[]

# Sum `duration_seconds` per program class and display up to 100 rows,
# untruncated, with the largest total first.
(
    full_log.groupBy("ProgramClassCD", "ProgramClass_Description")
    .agg(F.sum("duration_seconds").alias("duration_total"))
    .orderBy(F.col("duration_total").desc())
    .show(n=100, truncate=False)
)

# +--------------+--------------------------------------+--------------+
# |ProgramClassCD|ProgramClass_Description              |duration_total|
# +--------------+--------------------------------------+--------------+
# |PGR           |PROGRAM                               |652802250     |
# |COM           |COMMERCIAL MESSAGE                    |106810189     |
# |PFS           |PROGRAM FIRST SEGMENT                 |38817891      |
# |SEG           |SEGMENT OF A PROGRAM                  |34891264      |
# |PRC           |PROMOTION OF UPCOMING CANADIAN PROGRAM|27017583      |
# |PGI           |PROGRAM INFOMERCIAL                   |23196392      |
# |PRO           |PROMOTION OF NON-CANADIAN PROGRAM     |10213461      |
# |OFF           |SCHEDULED OFF AIR TIME PERIOD         |4537071       |
# [... more rows]
# |COR           |CORNERSTONE                           |null          |
# +--------------+--------------------------------------+--------------+

# end::ch05-simple-counting[]

# tag::ch05-dict-counting[]

# Same aggregation expressed with a {column: function} dictionary; the
# generated `sum(duration_seconds)` column is then renamed to
# `duration_total` before sorting and display.
(
    full_log.groupBy("ProgramClassCD", "ProgramClass_Description")
    .agg({"duration_seconds": "sum"})
    .withColumnRenamed("sum(duration_seconds)", "duration_total")
    .orderBy(F.col("duration_total").desc())
    .show(n=100, truncate=False)
)

# end::ch05-dict-counting[]

# tag::ch05-grouped-data[]

# Grouping (here with no key at all) yields a GroupedData object rather than
# a data frame; the result is not bound to a name.
full_log.groupBy()
# <pyspark.sql.group.GroupedData at 0x119baa4e0>

# end::ch05-grouped-data[]
# tag::ch05-selected-classes-code[]

# Stand-alone column expression (the result is not bound to a name): yields
# `duration_seconds` when the trimmed program class code is one of the listed
# codes, and 0 otherwise.
# NOTE(review): this list contains "PSA" and "MAG", which the aggregation
# that builds `answer` further down does not include -- confirm which list
# is the intended one.
F.when(
    F.trim(F.col("ProgramClassCD")).isin(
        "COM", "PRC", "PGI", "PRO", "PSA", "MAG", "LOC", "SPO", "MER", "SOL"
    ),
    F.col("duration_seconds"),
).otherwise(0)

# end::ch05-selected-classes-code[]

# fmt:off
# tag::ch05-final[]
# Per channel identifier: seconds spent on the listed program classes, total
# seconds, and the ratio between the two.
answer = (
    full_log.groupBy("LogIdentifierID")
    .agg(
        F.sum(                                                                # <1>
            F.when(                                                           # <1>
                F.trim(F.col("ProgramClassCD")).isin(                         # <1>
                    "COM", "PRC", "PGI", "PRO", "LOC", "SPO", "MER", "SOL"    # <1>
                ),                                                            # <1>
                F.col("duration_seconds"),                                    # <1>
            ).otherwise(0)                                                    # <1>
        ).alias("duration_commercial"),                                       # <1>
        F.sum("duration_seconds").alias("duration_total"),
    )
    .withColumn(
        "commercial_ratio",
        F.col("duration_commercial") / F.col("duration_total"),
    )
)

# Display up to 1000 rows, untruncated, highest ratio first.
answer.orderBy(F.col("commercial_ratio").desc()).show(n=1000, truncate=False)

# +---------------+-------------------+--------------+---------------------+
# |LogIdentifierID|duration_commercial|duration_total|commercial_ratio     |
# +---------------+-------------------+--------------+---------------------+
# |HPITV          |403                |403           |1.0                  |
# |TLNSP          |234455             |234455        |1.0                  |
# |MSET           |101670             |101670        |1.0                  |
# |TELENO         |545255             |545255        |1.0                  |
# |CIMT           |19935              |19935         |1.0                  |
# |TANG           |271468             |271468        |1.0                  |
# |INVST          |623057             |633659        |0.9832686034602207   |
# [...]
# |OTN3           |0                  |2678400       |0.0                  |
# |PENT           |0                  |2678400       |0.0                  |
# |ATN14          |0                  |2678400       |0.0                  |
# |ATN11          |0                  |2678400       |0.0                  |
# |ZOOM           |0                  |2678400       |0.0                  |
# |EURO           |0                  |null          |null                 |
# |NINOS          |0                  |null          |null                 |
# +---------------+-------------------+--------------+---------------------+
# end::ch05-final[]
# fmt:on

# tag::ch05-drop1[]
# Discard the rows whose `commercial_ratio` is null.
answer_no_null = answer.dropna(how="any", subset=["commercial_ratio"])

answer_no_null.orderBy(F.col("commercial_ratio").desc()).show(
    n=1000, truncate=False
)

# +---------------+-------------------+--------------+---------------------+
# |LogIdentifierID|duration_commercial|duration_total|commercial_ratio     |
# +---------------+-------------------+--------------+---------------------+
# |HPITV          |403                |403           |1.0                  |
# |TLNSP          |234455             |234455        |1.0                  |
# |MSET           |101670             |101670        |1.0                  |
# |TELENO         |545255             |545255        |1.0                  |
# |CIMT           |19935              |19935         |1.0                  |
# |TANG           |271468             |271468        |1.0                  |
# |INVST          |623057             |633659        |0.9832686034602207   |
# [...]
# |OTN3           |0                  |2678400       |0.0                  |
# |PENT           |0                  |2678400       |0.0                  |
# |ATN14          |0                  |2678400       |0.0                  |
# |ATN11          |0                  |2678400       |0.0                  |
# |ZOOM           |0                  |2678400       |0.0                  |
# +---------------+-------------------+--------------+---------------------+

print(answer_no_null.count())  # 322
# end::ch05-drop1[]

# tag::ch05-fill1[]
# Replace nulls with 0 instead of dropping the rows that contain them.
answer_no_null = answer.fillna(value=0)

answer_no_null.orderBy(F.col("commercial_ratio").desc()).show(
    n=1000, truncate=False
)

# +---------------+-------------------+--------------+---------------------+
# |LogIdentifierID|duration_commercial|duration_total|commercial_ratio     |
# +---------------+-------------------+--------------+---------------------+
# |HPITV          |403                |403           |1.0                  |
# |TLNSP          |234455             |234455        |1.0                  |
# |MSET           |101670             |101670        |1.0                  |
# |TELENO         |545255             |545255        |1.0                  |
# |CIMT           |19935              |19935         |1.0                  |
# |TANG           |271468             |271468        |1.0                  |
# |INVST          |623057             |633659        |0.9832686034602207   |
# [...]
# |OTN3           |0                  |2678400       |0.0                  |
# |PENT           |0                  |2678400       |0.0                  |
# |ATN14          |0                  |2678400       |0.0                  |
# |ATN11          |0                  |2678400       |0.0                  |
# |ZOOM           |0                  |2678400       |0.0                  |
# +---------------+-------------------+--------------+---------------------+

print(answer_no_null.count())  # 324 <1>
# end::ch05-fill1[]

# tag::ch05-fill2[]

# Dictionary form of fillna: one replacement value per column name.
answer_no_null = answer.fillna(
    value={
        "duration_commercial": 0,
        "duration_total": 0,
        "commercial_ratio": 0,
    }
)

# end::ch05-fill2[]
