import marimo

# Version string recorded by the tool that generated this notebook file
# (presumably the marimo release that last saved it).
__generated_with = "0.9.17"
# Application object; the @app.cell decorators below attach their cells to it.
app = marimo.App(width="medium")


@app.cell
def __():
    # Import the plotting and numerical libraries and return them; the cells
    # below name `plt` / `np` as parameters to use these modules.
    import matplotlib.pyplot as plt
    import numpy as np
    return np, plt


@app.cell
def __(plt):
    # Global matplotlib font defaults applied to every figure drawn afterwards:
    # Times New Roman at 16 pt.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return


@app.cell(hide_code=True)
def __(np, plt):
    # Draw a 2x3 grid of line charts comparing the LoRA operator with and
    # without graph pruning, then save it to "batchlora_op_cmp.pdf".
    #   * top row:    average forward + backward time
    #   * bottom row: peak memory
    #   * columns:    1.1B, 7B and 13B models
    #   * x axis:     number of simultaneously trained LoRA adapters
    #
    # Every panel is described by one entry of `_panels` and drawn by a single
    # loop instead of six copy-pasted sections.
    #
    # Names starting with "_" are cell-local temporaries (marimo does not
    # export underscore-prefixed names). The un-prefixed names are the ones
    # this cell returns; after the loop they hold the values of the last panel
    # (bottom-right), exactly as before the refactor.
    #
    # `np` is no longer referenced in the body (the only use was a dead
    # `x = np.arange(4)` assignment); it stays in the signature so the cell's
    # declared inputs are unchanged.
    fig, ax = plt.subplots(
        figsize=(7, 3.5), ncols=3, nrows=2, layout="constrained"
    )

    # Colour palette as RGB triples in [0, 1]. Only c_2 and c_4 are used for
    # drawing; c_1 and c_3 are kept because the cell returns them.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # One entry per panel, in drawing order. "torch" is plotted as the
    # "without Graph Pruning" series and "batchlora" as the "with Graph
    # Pruning" series. The x positions are 1..len(series).
    _panels = [
        {
            "pos": (0, 0),
            "torch": [
                20.29159868,
                39.16359761,
                55.65218289,
                73.73416966,
                88.12034672,
                104.6562161,
                124.9680397,
                138.4435594,
            ],
            "batchlora": [
                24.05060625,
                36.4906544,
                49.58459261,
                62.96516142,
                75.19585842,
                86.59812198,
                104.7908515,
                116.5208559,
            ],
            "ylim": (0, 170),
            "yticks": [0, 50, 100, 150],
            "model": "Model:1.1B",
            "ylabel": "Time (us)",
        },
        {
            "pos": (0, 1),
            "torch": [
                21.9846773,
                40.29510077,
                59.41132316,
                77.31547579,
                95.81772611,
            ],
            "batchlora": [
                23.42831809,
                37.3211666,
                51.31030688,
                65.47136931,
                79.53268476,
            ],
            "ylim": (0, 170),
            "yticks": [0, 50, 100, 150],
            "model": "Model:7B",
        },
        {
            "pos": (0, 2),
            "torch": [22.37562835, 42.73959994, 61.96534634],
            "batchlora": [22.80589938, 39.08431157, 54.03284915],
            "ylim": (0, 85),
            "yticks": [0, 25, 50, 75],
            "model": "Model:13B",
        },
        {
            "pos": (1, 0),
            "torch": [
                3.501786362,
                5.201629498,
                6.926035673,
                8.653747242,
                10.38349666,
                12.11389446,
                13.84254159,
                15.57186355,
            ],
            "batchlora": [
                3.393746125,
                4.998805098,
                6.571097296,
                8.141429216,
                9.714332745,
                11.28814712,
                12.85960523,
                14.43241727,
            ],
            "ylim": (0, 18),
            "yticks": [0, 5, 10, 15],
            "model": "Model:1.1B",
            "ylabel": "Peak Memory (GB)",
        },
        {
            "pos": (1, 1),
            "torch": [
                11.09702569,
                14.20903317,
                17.3902113,
                20.59260555,
                23.79299152,
            ],
            "batchlora": [
                10.8060544,
                13.69216517,
                16.51687164,
                19.33916381,
                22.15532633,
            ],
            "ylim": (10, 27),
            "yticks": [10, 15, 20, 25],
            "model": "Model:7B",
        },
        {
            "pos": (1, 2),
            "torch": [18.52022991, 22.66454649, 26.87934514],
            "batchlora": [18.16663067, 22.02296134, 25.78340335],
            "ylim": (15, 32),
            "yticks": [15, 20, 25, 30],
            "model": "Model:13B",
        },
    ]

    for _index, _panel in enumerate(_panels):
        _row, _col = _panel["pos"]
        _axis = ax[_row][_col]
        _is_first = _index == 0

        y_torch = _panel["torch"]
        y_batchlora = _panel["batchlora"]
        x = list(range(1, len(y_torch) + 1))
        ticks = list(x)
        ticks_label = [str(t) for t in ticks]

        # Relative reduction (in percent) of the pruned operator versus the
        # baseline. It is computed and exported but not plotted.
        y_imporve = [(t - b) / t * 100 for t, b in zip(y_torch, y_batchlora)]
        if _is_first:
            # Console output kept from the original cell: only the first
            # panel's improvement list is printed.
            print(y_imporve)

        # Only the first panel carries legend labels, so fig.legend() below
        # shows each series exactly once.
        _torch_style = {"color": c_2, "marker": "o"}
        _lora_style = {"color": c_4, "marker": "*"}
        if _is_first:
            _torch_style["label"] = "Operator without Graph Pruning"
            _lora_style["label"] = "Operator with Graph Pruning"
        _axis.plot(x, y_torch, **_torch_style)
        _axis.plot(x, y_batchlora, **_lora_style)

        _axis.set_ylim(*_panel["ylim"])
        _axis.set_xticks(ticks)
        _axis.set_xticklabels(ticks_label)
        _axis.set_yticks(_panel["yticks"])
        # y tick labels are rotated to read along the axis.
        _axis.set_yticklabels(
            [str(t) for t in _panel["yticks"]],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        if "ylabel" in _panel:
            _axis.set_ylabel(_panel["ylabel"])
        _axis.tick_params(pad=7)

        # Model-size tag in the lower-right corner, in axes coordinates.
        _axis.text(
            0.95,
            0.02,
            _panel["model"],
            fontsize=14,
            va="bottom",
            ha="right",
            transform=_axis.transAxes,
        )

    # The shared x label and the per-row captions are attached to the middle
    # column so they appear centred over the three panels.
    ax[1][1].set_xlabel("Number of simultaneously trained LoRA adapters")

    ax[0][1].set_title(
        "(a) The average time of forward and backward in the LoRA Operator",
        fontsize=16,
    )
    ax[1][1].set_title("(b) The peak memory used of LoRA Operator", fontsize=16)

    # Frameless figure-level legend anchored above the top-right corner.
    fig.legend(
        ncol=3,
        bbox_to_anchor=(1, 1.12),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    plt.savefig("batchlora_op_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_imporve,
        y_torch,
    )


# Run the notebook's app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()