"""Code for the book PySpark in Action, Chapter 4."""

# Shared setup: every listing below uses this `spark` session and the `F`
# alias for pyspark.sql.functions.
# tag::relevant-imports[]

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

# end::relevant-imports[]


# Build a tiny DataFrame from an in-memory list of rows (one inner list per
# row, columns named explicitly) and print the schema Spark infers from the
# Python values: str -> string, int -> long, float -> double, as shown in
# the sample output below.
# tag::ch04-grocery-list[]

my_grocery_list = [
    ["Banana", 2, 1.74],
    ["Apple", 4, 2.04],
    ["Carrot", 1, 1.09],
    ["Cake", 1, 10.99],
]  # <1>

df_grocery_list = spark.createDataFrame(
    my_grocery_list, ["Item", "Quantity", "Price"]
)

df_grocery_list.printSchema()
# root
#  |-- Item: string (nullable = true)   <2>
#  |-- Quantity: long (nullable = true) <2>
#  |-- Price: double (nullable = true)  <2>

# end::ch04-grocery-list[]

# Load the pipe-delimited ("|") broadcast-log extract into `logs`. The file
# name suggests 2018 Q3, month 8 (NOTE(review): inferred from the name only).
# Each numbered callout marks one reader option discussed in the chapter;
# header/inferSchema are what yield the named integer/timestamp columns in
# the schema printed further down.
# tag::ch04-reading-dsv[]

import os

DIRECTORY = "./data/broadcast_logs"
logs = spark.read.csv(
    os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"),  # <1>
    sep="|",  # <2>
    header=True,  # <3>
    inferSchema=True,  # <4>
    timestampFormat="yyyy-MM-dd",  # <5>
)

# end::ch04-reading-dsv[]

# The same read as the previous listing, but with the file passed by keyword
# (`path=`) instead of positionally; the result replaces `logs`.
# NOTE(review): this reads (and, with inferSchema=True, re-infers) the file a
# second time purely to illustrate the keyword form.
# tag::ch04-reading-dsv2[]
logs = spark.read.csv(
    path=os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"),
    sep="|",
    header=True,
    inferSchema=True,
    timestampFormat="yyyy-MM-dd",
)
# end::ch04-reading-dsv2[]


# Print the schema Spark inferred for `logs`; the expected output is
# reproduced in the comment block below.
# tag::ch04-print-schema[]

logs.printSchema()
# root
#  |-- BroadcastLogID: integer (nullable = true)
#  |-- LogServiceID: integer (nullable = true)
#  |-- LogDate: timestamp (nullable = true)
#  |-- SequenceNO: integer (nullable = true)
#  |-- AudienceTargetAgeID: integer (nullable = true)
#  |-- AudienceTargetEthnicID: integer (nullable = true)
#  |-- CategoryID: integer (nullable = true)
#  |-- ClosedCaptionID: integer (nullable = true)
#  |-- CountryOfOriginID: integer (nullable = true)
#  |-- DubDramaCreditID: integer (nullable = true)
#  |-- EthnicProgramID: integer (nullable = true)
#  |-- ProductionSourceID: integer (nullable = true)
#  |-- ProgramClassID: integer (nullable = true)
#  |-- FilmClassificationID: integer (nullable = true)
#  |-- ExhibitionID: integer (nullable = true)
#  |-- Duration: string (nullable = true)
#  |-- EndTime: string (nullable = true)
#  |-- LogEntryDate: timestamp (nullable = true)
#  |-- ProductionNO: string (nullable = true)
#  |-- ProgramTitle: string (nullable = true)
#  |-- StartTime: string (nullable = true)
#  |-- Subtitle: string (nullable = true)
#  |-- NetworkAffiliationID: integer (nullable = true)
#  |-- SpecialAttentionID: integer (nullable = true)
#  |-- BroadcastOriginPointID: integer (nullable = true)
#  |-- CompositionID: integer (nullable = true)
#  |-- Producer1: string (nullable = true)
#  |-- Producer2: string (nullable = true)
#  |-- Language1: integer (nullable = true)
#  |-- Language2: integer (nullable = true)

# end::ch04-print-schema[]

# Preview three columns of `logs`: the first five rows, with cell values
# printed in full rather than clipped (truncate=False).
# tag::ch04-simple-selection[]

logs.select("BroadcastLogID", "LogServiceID", "LogDate").show(
    n=5, truncate=False
)

# +--------------+------------+-------------------+
# |BroadcastLogID|LogServiceID|LogDate            |
# +--------------+------------+-------------------+
# |1196192316    |3157        |2018-08-01 00:00:00|
# |1196192317    |3157        |2018-08-01 00:00:00|
# |1196192318    |3157        |2018-08-01 00:00:00|
# |1196192319    |3157        |2018-08-01 00:00:00|
# |1196192320    |3157        |2018-08-01 00:00:00|
# +--------------+------------+-------------------+
# only showing top 5 rows

# end::ch04-simple-selection[]

# Four equivalent ways to select the same three columns: bare names, an
# unpacked list of names, explicit Column objects, and an unpacked list of
# Column objects. The resulting DataFrames are not assigned or acted upon;
# the listing only shows that each form is accepted.
# Names are spelled exactly as in the schema ("BroadcastLogID"); a
# differently-cased spelling would only resolve while Spark's
# spark.sql.caseSensitive setting is false.
# tag::ch04-select-equivalence[]

# Using the string to column conversion
logs.select("BroadcastLogID", "LogServiceID", "LogDate")
logs.select(*["BroadcastLogID", "LogServiceID", "LogDate"])

# Passing the column object explicitly
logs.select(
    F.col("BroadcastLogID"), F.col("LogServiceID"), F.col("LogDate")
)
logs.select(
    *[F.col("BroadcastLogID"), F.col("LogServiceID"), F.col("LogDate")]
)

# end::ch04-select-equivalence[]

# Display the columns of `logs` a few at a time so each table fits on
# screen. np.array_split makes len(columns) // 3 near-equal groups, so most
# groups hold three columns and some may hold one more when the column count
# is not a multiple of three. max(1, ...) keeps the section count positive:
# array_split raises ValueError when asked for zero sections, which would
# otherwise happen with fewer than three columns.
# tag::ch04-select-slice[]

import numpy as np

column_split = np.array_split(
    np.array(logs.columns), max(1, len(logs.columns) // 3)
)  # <1>

print(column_split)

# [array(['BroadcastLogID', 'LogServiceID', 'LogDate'], dtype='<U22'),
#  [...]
#  array(['Producer2', 'Language1', 'Language2'], dtype='<U22')]

for x in column_split:
    logs.select(*x).show(5, False)

# +--------------+------------+-------------------+
# |BroadcastLogID|LogServiceID|LogDate            |
# +--------------+------------+-------------------+
# |1196192316    |3157        |2018-08-01 00:00:00|
# |1196192317    |3157        |2018-08-01 00:00:00|
# |1196192318    |3157        |2018-08-01 00:00:00|
# |1196192319    |3157        |2018-08-01 00:00:00|
# |1196192320    |3157        |2018-08-01 00:00:00|
# +--------------+------------+-------------------+
# only showing top 5 rows
# ... and more tables of 3 columns

# end::ch04-select-slice[]

# Drop two columns from `logs`, then confirm they are gone. The membership
# checks spell each name exactly as in the schema: `in logs.columns` is a
# plain, case-sensitive Python list test, so a mis-cased name would print
# False even if the column were still present.
# tag::ch04-drop[]

logs = logs.drop("BroadcastLogID", "SequenceNO")

# Testing if we effectively got rid of the columns

print("BroadcastLogID" in logs.columns)  # => False
print("SequenceNO" in logs.columns)  # => False

# end::ch04-drop[]

# The drop above, expressed instead as a select of every column except the
# unwanted two. Here those columns are already gone, so every remaining
# column is kept.
# tag::ch04-drop-select[]

logs = logs.select(
    [
        column
        for column in logs.columns
        if column not in {"BroadcastLogID", "SequenceNO"}
    ]
)

# end::ch04-drop-select[]

# Peek at raw Duration values and confirm, via dtypes, that Spark read the
# column as a string rather than a time type.
# tag::ch04-duration-display[]

logs.select("Duration").show(n=5)

# +----------------+
# |        Duration|
# +----------------+
# |02:00:00.0000000|
# |00:00:30.0000000|
# |00:00:15.0000000|
# |00:00:15.0000000|
# |00:00:15.0000000|
# +----------------+
# only showing top 5 rows

print(logs.select("Duration").dtypes)  # <1>

# [('Duration', 'string')]

# end::ch04-duration-display[]

# Split Duration into hour / minute / second integers by position.
# Column.substr is 1-based here: (1, 2), (4, 2) and (7, 2) pick the HH, MM
# and SS digits, as the sample output shows (00:10:06 -> 0, 10, 6).
# NOTE(review): assumes every Duration follows the sampled
# "HH:MM:SS.fffffff" layout; other layouts would slice the wrong characters.
# distinct() removes repeated rows before five are shown.
# tag::ch04-duration-substr[]

logs.select(
    F.col("Duration"),  # <1>
    F.col("Duration").substr(1, 2).cast("int").alias("dur_hours"),  # <2>
    F.col("Duration").substr(4, 2).cast("int").alias("dur_minutes"),  # <3>
    F.col("Duration").substr(7, 2).cast("int").alias("dur_seconds"),  # <4>
).distinct().show(  # <5>
    5
)

# +----------------+---------+-----------+-----------+
# |        Duration|dur_hours|dur_minutes|dur_seconds|
# +----------------+---------+-----------+-----------+
# |00:10:06.0000000|        0|         10|          6|
# |00:10:37.0000000|        0|         10|         37|
# |00:04:52.0000000|        0|          4|         52|
# |00:26:41.0000000|        0|         26|         41|
# |00:08:18.0000000|        0|          8|         18|
# +----------------+---------+-----------+-----------+
# only showing top 5 rows

# end::ch04-duration-substr[]

# Fold the positional HH / MM / SS slices of Duration into one total-seconds
# value (hours * 3600 + minutes * 60 + seconds) and preview five distinct
# rows next to the original string.
# tag::ch04-duration-field[]

logs.select(
    F.col("Duration"),
    (
        F.col("Duration").substr(1, 2).cast("int") * 3600
        + F.col("Duration").substr(4, 2).cast("int") * 60
        + F.col("Duration").substr(7, 2).cast("int")
    ).alias("Duration_seconds"),
).distinct().show(n=5)

# +----------------+----------------+
# |        Duration|Duration_seconds|
# +----------------+----------------+
# |00:10:30.0000000|             630|
# |00:25:52.0000000|            1552|
# |00:28:08.0000000|            1688|
# |06:00:00.0000000|           21600|
# |00:32:08.0000000|            1928|
# +----------------+----------------+
# only showing top 5 rows

# end::ch04-duration-field[]

# Persist the total-seconds computation on `logs` as a new column,
# Duration_seconds, appended after the existing ones (see <1> in the
# printed schema).
# tag::ch04-duration-with-column[]

logs = logs.withColumn(
    "Duration_seconds",
    F.col("Duration").substr(1, 2).cast("int") * 3600
    + F.col("Duration").substr(4, 2).cast("int") * 60
    + F.col("Duration").substr(7, 2).cast("int"),
)

logs.printSchema()

# root
#  |-- LogServiceID: integer (nullable = true)
#  |-- LogDate: timestamp (nullable = true)
#  |-- AudienceTargetAgeID: integer (nullable = true)
#  |-- AudienceTargetEthnicID: integer (nullable = true)
#  [... more columns]
#  |-- Language2: integer (nullable = true)
#  |-- Duration_seconds: integer (nullable = true)  <1>

# end::ch04-duration-with-column[]

# Rename a single column (Duration_seconds -> duration_seconds) and print
# the schema to show the new name.
# tag::ch04-with-column-renamed[]

logs = logs.withColumnRenamed("Duration_seconds", "duration_seconds")

logs.printSchema()

# root
#  |-- LogServiceID: integer (nullable = true)
#  |-- LogDate: timestamp (nullable = true)
#  |-- AudienceTargetAgeID: integer (nullable = true)
#  |-- AudienceTargetEthnicID: integer (nullable = true)
#  [...]
#  |-- Language2: integer (nullable = true)
#  |-- duration_seconds: integer (nullable = true)

# end::ch04-with-column-renamed[]

# Rename every column at once by handing toDF the lower-cased names in their
# original order. The result is only printed; `logs` itself is not
# reassigned.
# tag::ch04-batch-rename[]

logs.toDF(*map(str.lower, logs.columns)).printSchema()

# root
#  |-- logserviceid: integer (nullable = true)
#  |-- logdate: timestamp (nullable = true)
#  |-- audiencetargetageid: integer (nullable = true)
#  |-- audiencetargetethnicid: integer (nullable = true)
#  |-- categoryid: integer (nullable = true)
#  [...]
#  |-- language2: integer (nullable = true)
#  |-- duration_seconds: integer (nullable = true)

# end::ch04-batch-rename[]

# Reorder the columns alphabetically. Python's sorted() places uppercase
# letters before lowercase ones, which is why the all-lowercase
# duration_seconds lands last (<1>).
# tag::ch04-sorting-columns[]

logs.select(*sorted(logs.columns)).printSchema()

# root
#  |-- AudienceTargetAgeID: integer (nullable = true)
#  |-- AudienceTargetEthnicID: integer (nullable = true)
#  |-- BroadcastOriginPointID: integer (nullable = true)
#  |-- CategoryID: integer (nullable = true)
#  |-- ClosedCaptionID: integer (nullable = true)
#  |-- CompositionID: integer (nullable = true)
#  [...]
#  |-- Subtitle: string (nullable = true)
#  |-- duration_seconds: integer (nullable = true) <1>

# end::ch04-sorting-columns[]

# Show describe() statistics one column at a time, so each table stays
# narrow. Columns that describe() produces no statistics for come back with
# only the summary labels, as in the second sample table (<2>).
# tag::ch04-describe[]

for column_name in logs.columns:
    logs.describe(column_name).show()

# +-------+------------------+ <1>
# |summary|      LogServiceID|
# +-------+------------------+
# |  count|           7169318|
# |   mean|3453.8804215407936|
# | stddev|200.44137201584468|
# |    min|              3157|
# |    max|              3925|
# +-------+------------------+
#
# [...]
#
# +-------+ <2>
# |summary|
# +-------+
# |  count|
# |   mean|
# | stddev|
# |    min|
# |    max|
# +-------+

# [... many more little tables]

# end::ch04-describe[]

# summary() per column: first with its default statistics (which add the
# 25% / 50% / 75% rows compared with describe()), then with an explicit
# selection of statistics and percentiles.
# tag::ch04-summarize[]

for column_name in logs.columns:
    logs.select(column_name).summary().show()  # <1>

# +-------+------------------+
# |summary|      LogServiceID|
# +-------+------------------+
# |  count|           7169318|
# |   mean|3453.8804215407936|
# | stddev|200.44137201584468|
# |    min|              3157|
# |    25%|              3291|
# |    50%|              3384|
# |    75%|              3628|
# |    max|              3925|
# +-------+------------------+
#
# [... many more slightly larger tables]

for column_name in logs.columns:
    logs.select(column_name).summary("min", "10%", "90%", "max").show()  # <2>

# +-------+------------+
# |summary|LogServiceID|
# +-------+------------+
# |    min|        3157|
# |    10%|        3237|
# |    90%|        3710|
# |    max|        3925|
# +-------+------------+
#
# [...]
# end::ch04-summarize[]
