#!/usr/bin/env python3

# tag::ch10-jaccard[]

# Every unordered pair of binary columns: the lexicographic `<` test drops
# self-pairs and keeps only one of (a, b) / (b, a) -- each pair appears once
# assuming the names in BINARY_COLUMNS are unique.
columns_pairing = [
    (left, right)
    for left in BINARY_COLUMNS
    for right in BINARY_COLUMNS
    if left < right
]


def chunks(list_of_items, size_of_chunks=250):
    """Lazily yield consecutive slices of *list_of_items*.

    Each slice holds at most *size_of_chunks* items; the last one may be
    shorter. Slices keep the type of the input sequence (list -> list,
    str -> str).
    """
    starts = range(0, len(list_of_items), size_of_chunks)
    yield from (list_of_items[start : start + size_of_chunks] for start in starts)


def jaccard_columns(df, list_of_pairings, threshold=0.8):
    """Return the column pairs of *df* whose Jaccard similarity exceeds *threshold*.

    For every ``(x, y)`` in *list_of_pairings*, a single aggregation over *df*
    computes

        M11 / (sum(x) + sum(y) - M11),   with M11 = sum(x * y),

    which is the Jaccard index when both columns only contain 0/1 values
    (assumed, not checked here).

    Args:
        df: Spark DataFrame holding the columns named in the pairings.
        list_of_pairings: iterable of ``(x, y)`` column-name tuples; must not
            be empty (``DataFrame.agg`` needs at least one expression).
        threshold: strict lower bound; only similarities ``> threshold`` are
            kept.

    Returns:
        A pandas Series indexed by ``(f"{x}_{y}", 0)`` with the similarities
        above *threshold*. A pair whose denominator is zero (for 0/1 columns:
        neither column ever equals 1) gets a null similarity and is filtered
        out.
    """

    def M11(col1, col2):
        # For 0/1 columns: number of rows where both columns are 1.
        return F.sum(F.col(col1) * F.col(col2))

    def _jaccard(x, y):
        both = M11(x, y)
        union = F.sum(x) + F.sum(y) - both
        # Guard the division: with ANSI mode on (the default since Spark 4.0)
        # dividing by zero raises instead of returning null, which would abort
        # the whole aggregation because of a single all-zero pair. `when`
        # without `otherwise` yields null, matching the non-ANSI result.
        return F.when(union != 0, both / union)

    to_compute = [
        _jaccard(x, y).alias(f"{x}_{y}") for x, y in list_of_pairings
    ]
    # The aggregation yields a single row; unstack() turns that one-row pandas
    # DataFrame into a Series keyed by (alias, row index 0).
    jaccard_df = df.agg(*to_compute).toPandas().unstack()
    return jaccard_df[jaccard_df > threshold]


# Compute the pairwise Jaccard similarities chunk by chunk (250 pairs per
# aggregation by default) and print the pairs above the default 0.8 threshold.
for chunk in chunks(columns_pairing):
    print(jaccard_columns(food, chunk))

# Sample output -- pairs whose Jaccard similarity exceeds 0.8 (the middle 0 is
# the pandas row index left over from unstack()):
#
# england_london  0    1.0
# kosher_pescatarian  0    0.836038
# peanut_free_soy_free       0    0.931997
# peanut_free_tree_nut_free  0    0.801144

# Per the sample output above, "london" has a Jaccard similarity of 1.0 with
# "england", so it looks redundant; drop it from the binary features.
# list.remove() raises ValueError if "london" is no longer in the list (e.g.
# when this cell is re-run in the same session).
# NOTE(review): columns_pairing was built before this removal and still
# contains pairs involving "london".
BINARY_COLUMNS.remove("london")

# end::ch10-jaccard[]


# def compute_jaccard_for_feature(df, row, columns):

#     # We build all the aggregation expressions we'll need ahead of time to
#     # avoid having a massive group by
#     agg_sequence = []
#     jaccard_sequence = []

#     @F.pandas_udf(T.DoubleType())
#     def compute_jaccard(
#         index: pd.Series, zero: pd.Series, one: pd.Series
#     ) -> float:
#         contingencies = pd.DataFrame(
#             [zero, one], index=index, columns=[0.0, 1.0]
#         )
#         M11 = contingencies.loc[1.0, 1.0].item()
#         M01 = contingencies.loc[0.0, 1.0].item()
#         M10 = contingencies.loc[1.0, 0.0].item()
#         return M11 / (M01 + M10 + M11)

#     for column in columns:
#         agg_sequence.append(
#             (F.count(column) - F.sum(column)).alias(f"{column}_0")
#         )
#         agg_sequence.append(F.sum(column).alias(f"{column}_1"))

#         jaccard_sequence.append(
#             compute_jaccard(
#                 F.col(row), F.col(f"{column}_0"), F.col(f"{column}_1")
#             ).alias(column)
#         )

#     frequency_df = (
#         df.groupby(row).agg(*agg_sequence).select(*jaccard_sequence)
#     )

#     return frequency_df
