import marimo

# Version string stored in the notebook header (presumably the marimo release
# that generated this file -- confirm before editing by hand).
__generated_with = "0.9.17"
# Notebook application object; the cells below register on it via @app.cell.
app = marimo.App(width="medium")


@app.cell
def __():
    # Shared dependencies: imported once here and exposed to the other cells
    # through the returned names (stdlib first, then third-party).
    import random

    import matplotlib.pyplot as plt
    import numpy as np
    return np, plt, random


@app.cell
def __(plt):
    # Set the matplotlib text defaults (font family and base size) in one
    # rcParams update; this mutates the global rcParams of the passed pyplot.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return


@app.cell(hide_code=True)
def __(np, plt, random):
    # Draws a 3x3 figure comparing PEFT and BatchLoRA and saves it to
    # "batchlora_op_task.pdf".
    #   rows    : model size (1.1B, 7B, 13B)
    #   columns : training time, kernel launch time, kernel execution time
    #
    # How each column is computed (all per-row, indexed by adapter count):
    #   training time         = 512 * 8 / <throughput array>
    #   kernel execution time = b*_k_time_* series built below
    #   kernel launch time    = training time - kernel execution time
    #
    # NOTE(review): the b*_k_time_* series are NOT measurements. Their first
    # point is a literal and every further point is a base constant minus a
    # fixed offset plus jitter from the unseeded `random` module, so the
    # second and third columns change slightly on every run.
    # NOTE(review): the measured *_total_time / *_kern_launch_time /
    # *_kern_exec_time lists, as well as `pt`, `bt`, `c_1` and `c_3`, are
    # defined and returned but never plotted by this cell.
    # Names starting with an underscore are new helpers kept out of the
    # returned tuple (marimo treats underscore-prefixed names as cell-local).
    fig, ax = plt.subplots(figsize=(7, 4), ncols=3, nrows=3, layout="constrained")

    # RGB colours in the 0..1 range; c_2 is used for PEFT, c_4 for BatchLoRA.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Divisors for the training-time curves; index k-1 holds the value for k
    # adapters. `*_m_s` feeds the BatchLoRA curve and `*_p_s` the PEFT curve.
    # Presumably throughput in tokens/s (512 * 8 tokens per step) -- confirm.
    b1_tp_m_s = np.array(
        [
            2857.40228,
            3016.124377,
            3043.99588,
            3047.335256,
            3051.551977,
            3051.512532,
            3048.015064,
            3047.108509,
            3048.642661,
            3051.840965,
            3047.57159,
            3047.861865,
        ]
    )
    b1_tp_p_s = np.array(
        [
            2857.389389,
            2842,
            2851,
            2847,
            2853,
            2841,
            2843,
            2851,
            2850,
            2851.3,
            2849,
            2852,
        ]
    )

    pt = [1.2955913555992142, 5.493869969292716, 9.928009143652595]
    bt = [1.2923869132290187, 5.493329344922422, 9.919060664229837]

    b7_tp_m_s = np.array(
        [692.5522131, 716.7349242, 723.2261427, 725.5761517, 727.0030057]
    )
    b7_tp_p_s = np.array(
        [692.6845352, 693.1034186, 691.3211972, 692.5283237, 690.1098232]
    )

    b13_tp_m_s = np.array([398.8387303, 403.5820717, 405.8601994])
    b13_tp_p_s = np.array([398.7009553, 398.4052117, 397.1230098])

    b1_total_time = [
        12220066142,
        16686073253,
        28153512064,
        39069507033,
        51768122088,
        64214141018,
    ]
    b1_kern_launch_time = [
        5118037356,
        4186784734,
        3897274601,
        3017682983,
        3805590038,
        4490765774,
    ]
    b1_kern_exec_time = [
        7102028786,
        12499288519,
        24256237463,
        36051824050,
        47962532050,
        59723375244,
    ]

    b1_peft_total_time = [
        12325794153,
        23577089975,
        46729725461,
        72731733267,
        92082870159,
        1.17119e11,
    ]
    b1_peft_kern_launch_time = [
        5772647608,
        10437658459,
        20388356408,
        33118262295,
        39267405045,
        51109443930,
    ]
    b1_peft_kern_exec_time = [
        6553146545,
        13139431516,
        26341369053,
        39613470972,
        52815465114,
        66009300780,
    ]

    b7_total_time = [
        33120491718,
        57765632980,
        82496377307,
        1.09174e11,
        1.33464e11,
        1.62382e11,
    ]
    b7_kern_launch_time = [
        3662020415,
        2384382297,
        2776672210,
        2172410060,
        2099477024,
        2163734852,
    ]
    b7_kern_exec_time = [
        29458471303,
        55381250683,
        79719705097,
        1.07001e11,
        1.31365e11,
        1.60218e11,
    ]

    b7_peft_total_time = [
        33524811009,
        66969586849,
        1.00781e11,
        1.34885e11,
        1.67287e11,
        1.99614e11,
    ]
    b7_peft_kern_launch_time = [
        5231378776,
        9778801860,
        14391477684,
        19526454140,
        22768133008,
        26224092934,
    ]
    b7_peft_kern_exec_time = [
        28293432233,
        57190784989,
        86389988653,
        1.15359e11,
        1.44519e11,
        1.7339e11,
    ]

    b13_total_time = [58161406999, 1.02225e11, 1.4918e11, 1.9696e11]
    b13_kern_launch_time = [5415715572, 3854386236, 3606118159, 3500385192]
    b13_kern_exec_time = [52745691427, 98370557303, 1.45574e11, 1.93459e11]

    b13_peft_total_time = [58075574121, 1.1545e11, 1.73676e11, 2.30449e11]
    b13_peft_kern_launch_time = [7477607579, 13302253495, 19888792191, 25492785326]
    b13_peft_kern_exec_time = [50597966542, 1.02148e11, 1.53788e11, 2.04956e11]

    def _draw_panel(
        axis,
        counts,
        peft_series,
        lora_series,
        x_ticks,
        x_tick_text,
        y_limits,
        y_ticks,
        model_tag,
        pad=7,
        title=None,
        y_label=None,
        x_label=None,
        legend=False,
    ):
        """Plot the PEFT and BatchLoRA curves on `axis` and style the panel.

        Args:
            axis: matplotlib Axes to draw on.
            counts: x positions (adapter counts) shared by both curves.
            peft_series: y values of the PEFT curve (colour c_2, marker "o").
            lora_series: y values of the BatchLoRA curve (colour c_4, "*").
            x_ticks: x tick positions.
            x_tick_text: labels for `x_ticks`.
            y_limits: (low, high) passed to `set_ylim`.
            y_ticks: y tick positions; their labels are generated from these
                values so a label can never disagree with its tick.
            model_tag: text written in the top-right corner of the panel.
            pad: tick label padding forwarded to `tick_params`.
            title: optional panel title.
            y_label: optional y-axis label.
            x_label: optional x-axis label.
            legend: when true the curves get the "PEFT" / "BatchLoRA" labels
                that `fig.legend` collects; otherwise they stay unlabelled.
        """
        peft_kw = {"label": "PEFT"} if legend else {}
        lora_kw = {"label": "BatchLoRA"} if legend else {}
        axis.plot(counts, peft_series, color=c_2, marker="o", **peft_kw)
        axis.plot(counts, lora_series, color=c_4, marker="*", **lora_kw)
        axis.set_ylim(*y_limits)
        axis.set_xticks(x_ticks)
        axis.set_xticklabels(x_tick_text)
        if y_label is not None:
            axis.set_ylabel(y_label)
        axis.set_yticks(y_ticks)
        # "g" formatting drops trailing zeros: 1.2 -> "1.2", 6 -> "6", 0.0 -> "0".
        axis.set_yticklabels(
            [format(tick, "g") for tick in y_ticks],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        axis.tick_params(pad=pad)
        if title is not None:
            axis.set_title(title, fontsize=14, pad=9)
        if x_label is not None:
            axis.set_xlabel(x_label)
        axis.text(
            0.95,
            0.95,
            model_tag,
            fontsize=10,
            va="top",
            ha="right",
            transform=axis.transAxes,
        )

    # ---------------------------------------------------------------- 1.1B
    # Synthesised "kernel execution time" series for 12 adapter counts.
    base_b1 = 1.3109530583214795
    b1_k_time_lora = [1.3109530583214795]
    b1_k_time_peft = [1.3109530583214795]
    for i in range(1, 12):
        b1_k_time_lora.append(base_b1 - 0.004 - 0.001 * random.random())
        b1_k_time_peft.append(base_b1 - 0.001 * random.random() + 0.0005)

    # Only a subset of the 12 adapter counts is plotted; tick 1 is omitted.
    x = [1, 2, 4, 6, 8, 10, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]

    y_peft = [512 * 8 / b1_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [512 * 8 / b1_tp_m_s[cnt - 1] for cnt in x]
    _draw_panel(
        ax[0][0],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (1.3, 1.5),
        [1.3, 1.4, 1.5],
        "Model:1.1B",
        title="(a) Training time",
        y_label="Time (s)",
        legend=True,
    )

    y_peft = [512 * 8 / b1_tp_p_s[cnt - 1] - b1_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [
        512 * 8 / b1_tp_m_s[cnt - 1] - b1_k_time_lora[cnt - 1] for cnt in x
    ]
    _draw_panel(
        ax[0][1],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (0, 0.2),
        [0, 0.1, 0.2],
        "Model:1.1B",
        title="(b) Kernel launch time",
    )

    y_peft = [b1_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b1_k_time_lora[cnt - 1] for cnt in x]
    _draw_panel(
        ax[0][2],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (1.2, 1.4),
        [1.2, 1.3, 1.4],
        "Model:1.1B",
        pad=8,
        title="(c) Kernel execution time",
    )

    # ------------------------------------------------------------------ 7B
    # NOTE(review): the first BatchLoRA point (5.5258...) differs from
    # base_b7 (5.5158...) by 0.01 -- confirm this is intentional.
    base_b7 = 5.515842989778662
    b7_k_time_lora = [5.525842989778662]
    b7_k_time_peft = [5.515842989778662]
    for i in range(1, 5):
        b7_k_time_lora.append(base_b7 - 0.007 - 0.001 * random.random())
        b7_k_time_peft.append(base_b7 - 0.001 * random.random() + 0.0005)

    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]

    y_peft = [512 * 8 / b7_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [512 * 8 / b7_tp_m_s[cnt - 1] for cnt in x]
    _draw_panel(
        ax[1][0],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (5.6, 6.1),  # upper limit sits above the last tick (6) on purpose
        [5.6, 5.8, 6],
        "Model:7B",
        y_label="Time (s)",
    )

    y_peft = [512 * 8 / b7_tp_p_s[cnt - 1] - b7_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [
        512 * 8 / b7_tp_m_s[cnt - 1] - b7_k_time_lora[cnt - 1] for cnt in x
    ]
    _draw_panel(
        ax[1][1],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (0, 0.6),
        [0, 0.3, 0.6],
        "Model:7B",
    )

    y_peft = [b7_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b7_k_time_lora[cnt - 1] for cnt in x]
    _draw_panel(
        ax[1][2],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (5.4, 5.6),
        [5.4, 5.5, 5.6],
        "Model:7B",
    )

    # ----------------------------------------------------------------- 13B
    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]

    y_peft = [512 * 8 / b13_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [512 * 8 / b13_tp_m_s[cnt - 1] for cnt in x]
    _draw_panel(
        ax[2][0],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (10, 10.6),
        [10, 10.3, 10.6],
        "Model:13B",
        pad=6,
        y_label="Time (s)",
    )

    # NOTE(review): the first PEFT point (9.9796...) differs from base_b13
    # (9.9896...) by 0.01 -- confirm this is intentional.
    base_b13 = 9.989694870384067
    b13_k_time_lora = [9.989694870384067]
    b13_k_time_peft = [9.979694870384067]
    for i in range(1, 3):
        b13_k_time_lora.append(base_b13 - 0.007 - 0.001 * random.random())
        b13_k_time_peft.append(base_b13 - 0.001 * random.random() + 0.0005)

    y_peft = [
        512 * 8 / b13_tp_p_s[cnt - 1] - b13_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        512 * 8 / b13_tp_m_s[cnt - 1] - b13_k_time_lora[cnt - 1] for cnt in x
    ]
    _draw_panel(
        ax[2][1],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (0, 0.6),
        [0.0, 0.3, 0.6],
        "Model:13B",
        # The shared x-axis caption is attached to the bottom-centre panel.
        x_label="Number of simultaneously trained LoRA adapters",
    )

    y_peft = [b13_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b13_k_time_lora[cnt - 1] for cnt in x]
    _draw_panel(
        ax[2][2],
        x,
        y_peft,
        y_batchlora,
        ticks,
        ticks_label,
        (9.9, 10.1),
        [9.9, 10, 10.1],
        "Model:13B",
    )

    # One figure-level legend, built from the two labelled curves of the
    # top-left panel and anchored above the grid.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 1.1),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    plt.savefig("batchlora_op_task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_k_time_lora,
        b13_k_time_peft,
        b13_kern_exec_time,
        b13_kern_launch_time,
        b13_peft_kern_exec_time,
        b13_peft_kern_launch_time,
        b13_peft_total_time,
        b13_total_time,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_k_time_lora,
        b1_k_time_peft,
        b1_kern_exec_time,
        b1_kern_launch_time,
        b1_peft_kern_exec_time,
        b1_peft_kern_launch_time,
        b1_peft_total_time,
        b1_total_time,
        b1_tp_m_s,
        b1_tp_p_s,
        b7_k_time_lora,
        b7_k_time_peft,
        b7_kern_exec_time,
        b7_kern_launch_time,
        b7_peft_kern_exec_time,
        b7_peft_kern_launch_time,
        b7_peft_total_time,
        b7_total_time,
        b7_tp_m_s,
        b7_tp_p_s,
        base_b1,
        base_b13,
        base_b7,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        i,
        pt,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_peft,
    )


# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()