import marimo

# marimo version string recorded in this generated notebook file.
__generated_with = "0.7.12"
# The notebook application; the cells below register on it via @app.cell.
app = marimo.App(width="medium")


@app.cell
def __():
    # Third-party libraries used by the notebook; returning them makes the
    # names available to the cells that list them as parameters.
    import numpy as np
    import matplotlib.pyplot as plt

    return np, plt


@app.cell
def __(plt):
    # Set the global matplotlib font defaults (family and size) in one call.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return


@app.cell
def __(np, plt):
    # Draw a three-panel line chart comparing the communication cost of
    # tensor parallelism (TP) with pipeline parallelism (LoRAPP / 1F1B) as the
    # number of GPUs grows, one panel per model, and save it as a PDF.
    # `np` is accepted but not used by this cell; it stays in the signature so
    # the cell's interface is unchanged.
    fig, ax = plt.subplots(figsize=(7, 2.6), ncols=3, layout="constrained")

    # Colour palette. Only c_2 (TP) and c_4 (pipeline) are drawn below; c_1
    # and c_3 are still defined because the cell returns them.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Cost model used for both curves (d is the number of GPUs):
    #   tp = 8 * l * b * h * (d - 1) / d
    #   pp = 2 * b * h * (d - 1)
    # token_size takes the place of b and is divided by 1024**3 so the
    # results match the "GB" axis label (presumably 512 tokens x batch 8 x
    # 4 bytes per value -- TODO confirm).
    token_size = 512 * 8 * 4 / 1024 / 1024 / 1024
    x = [2, 3, 4, 5, 6, 7, 8]

    # Per-model l and h values (presumably layer count and hidden size).
    l1 = 22
    h1 = 2048

    l2 = 32
    h2 = 4096

    l3 = 40
    h3 = 5120

    # One x tick per plotted GPU count; labels are the tick values as text.
    ticks = list(x)
    ticks_label = [str(t) for t in ticks]
    y_ticks = [0, 5, 10, 15, 20]
    y_ticks_label = [str(t) for t in y_ticks]

    # Underscore-prefixed names keep the loop temporaries local to this cell.
    _panels = (
        ("(a) TinyLlama-1.1B", l1, h1),
        ("(b) Llama-2-7B", l2, h2),
        ("(c) Llama-2-13B", l3, h3),
    )
    for _idx, (_title, _layers, _hidden) in enumerate(_panels):
        _axis = ax[_idx]
        _first = _idx == 0

        y_tp = [8 * _layers * token_size * _hidden * (d - 1) / d for d in x]
        y_pp = [2 * token_size * _hidden * (d - 1) for d in x]

        # Only the first panel's lines carry labels, so the shared figure
        # legend lists each series exactly once.
        _axis.plot(
            x, y_tp, color=c_2, label="TP" if _first else None, marker="o"
        )
        _axis.plot(
            x,
            y_pp,
            color=c_4,
            label="LoRAPP and 1F1B" if _first else None,
            marker="*",
        )

        # Identical axis limits and ticks on every panel keep them comparable.
        _axis.set_ylim([-1, 25])
        _axis.set_xticks(ticks)
        _axis.set_xticklabels(ticks_label)
        _axis.set_yticks(y_ticks)
        _axis.set_yticklabels(
            y_ticks_label, rotation=90, ha="center", va="center"
        )
        _axis.set_title(_title, fontsize=16)
        if _first:
            # The y-axis label is shown on the leftmost panel only.
            _axis.set_ylabel("Communication Cost (GB)", fontsize=16)
        _axis.set_xlabel("Number of GPUs", fontsize=16)
        _axis.tick_params(labelsize=16, pad=7)

    # Frameless two-column legend anchored above the panels.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.8, 1.17),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    # bbox_inches="tight" is requested so the saved area is fitted to the
    # drawn artists, including the legend placed above the axes.
    plt.savefig("pp_cmp_com_cost.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        h1,
        h2,
        h3,
        l1,
        l2,
        l3,
        ticks,
        ticks_label,
        token_size,
        x,
        y_pp,
        y_ticks,
        y_ticks_label,
        y_tp,
    )


# Run the app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()