import marimo

# marimo version string recorded in this notebook file.
__generated_with = "0.7.0"
# Notebook application object; every cell below registers itself on it
# through the @app.cell decorator.
app = marimo.App(width="medium")


@app.cell
def __():
    # Shared imports for the notebook. The returned names (np, plt, random)
    # are what other cells receive as parameters.
    import random

    import numpy as np
    from matplotlib import pyplot as plt

    return np, plt, random


@app.cell
def __(plt):
    # Figure-wide typography: set the default font family and size in
    # matplotlib's rcParams, which is process-global state. Nothing is
    # exported from this cell.
    plt.rcParams.update(
        {
            'font.family': 'Times New Roman',
            'font.size': 16,
        }
    )
    return


@app.cell(hide_code=True)
def __(plt):
    # Three-panel grouped bar chart. Each panel shows one pair of bars
    # (black c_2, coral c_4) for each of the three x tick labels; the
    # figure-level legend names the pair "PEFT" / "BatchLoRA".
    #
    # Names starting with an underscore are private to this cell; all other
    # assignments are exported through the return tuple at the bottom.

    # --- geometry ---------------------------------------------------------
    space_width = 3 / 22
    bar_width = 3 * space_width

    x_ticks_label = ["1.1B", "7B", "13B"]
    # NOTE(review): assuming matplotlib's default centre alignment for bars,
    # each pair's midpoint is 0.5 * space_width to the left of the tick
    # placed here -- confirm the small offset is intended.
    x_ticks = [
        bar_width,
        space_width + 3 * bar_width,
        2 * space_width + 5 * bar_width,
    ]

    # --- palette (RGB components in 0..1) ---------------------------------
    # Only c_2 and c_4 are drawn here; c_1 and c_3 are exported unchanged.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # --- data -------------------------------------------------------------
    # Raw kernel-time figures; not plotted in this cell, only exported.
    peft_kern_time = [12969965266, 56659466215, 96280798449]
    batchlora_kern_time = [12969965266, 55606549735, 96080798449]

    # Heights plotted in panel (c).
    pt = [1.3109530583214795, 5.493869969292716, 9.928009143652595]
    bt = [1.309003274394237, 5.493329344922422, 9.919060664229837]

    # Panel (a) heights are 512 * 8 divided by a measured rate.
    # NOTE(review): presumably a throughput, units are not shown here.
    _numerator = 512 * 8
    _panel_a_dark = [_numerator / _r for _r in (2812, 690.5880666, 396.4798927)]
    _panel_a_light = [_numerator / _r for _r in (3054, 727.7364507, 405.9223082)]

    # Panel (b) heights, in percent.
    _panel_b_dark = [10.1, 7.5, 3.9]
    _panel_b_light = [2.4, 2.1, 1.1]

    # --- helpers ----------------------------------------------------------
    def _draw_pairs(axis, dark, light, names=None):
        # Draw one (c_2, c_4) pair per group, in group order. When `names`
        # is given, only the first pair carries legend labels so each series
        # appears once in the legend.
        for group, heights in enumerate(zip(dark, light)):
            for slot, (height, colour) in enumerate(zip(heights, (c_2, c_4))):
                options = {"color": colour}
                if names is not None and group == 0:
                    options["label"] = names[slot]
                axis.bar(
                    (group + 1) * space_width + (2 * group + slot) * bar_width,
                    height,
                    bar_width,
                    **options,
                )

    def _decorate(axis, ylabel, title, **title_options):
        # Shared tick, axis-label and title styling for a panel.
        axis.set_xticks(x_ticks)
        axis.set_xticklabels(x_ticks_label, ha="center", va="center")
        axis.tick_params(bottom=False, labelsize=14, pad=7)
        axis.set_ylabel(ylabel, fontsize=16)
        axis.set_title(title, fontsize=16, **title_options)

    # --- figure -----------------------------------------------------------
    fig, ax = plt.subplots(figsize=(7, 2.4), ncols=3, layout="constrained")

    _draw_pairs(ax[0], _panel_a_dark, _panel_a_light)
    _decorate(ax[0], "Time (s)", "(a) Training time", pad=22)

    _draw_pairs(ax[1], _panel_b_dark, _panel_b_light)
    # Two-line title: no extra title pad is passed for this panel.
    _decorate(ax[1], "Percentage (%)", "(b) Proportion of \nkernel launch time")

    _draw_pairs(ax[2], pt, bt, names=("PEFT", "BatchLoRA"))
    _decorate(ax[2], "Time (s)", "(c) Kernel execution time", pad=22)

    # Frameless two-column legend anchored above the axes.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.8, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    # To export the figure:
    # plt.savefig("batchlora_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        batchlora_kern_time,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        peft_kern_time,
        pt,
        space_width,
        x_ticks,
        x_ticks_label,
    )


# Run the notebook's cells when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()