import marimo

# marimo version that wrote this notebook file (maintained by marimo).
__generated_with = "0.9.17"
# The notebook object; every function decorated with @app.cell below is one cell.
app = marimo.App(width="medium")


@app.cell
def __():
    # Third-party libraries shared by the cells below: numpy, the matplotlib
    # package (for global rcParams) and its pyplot interface (for figures).
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    return matplotlib, np, plt


@app.cell
def __(matplotlib, plt):
    # Global figure styling: render text with matplotlib itself (no external
    # LaTeX), in Times New Roman at 16 pt.
    matplotlib.rcParams["text.usetex"] = False
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return


@app.cell
def __(np, plt):
    # End-to-end comparison figure: average task-completion time (hours) versus
    # the number of LoRA adapters trained simultaneously.
    #   rows:    4x A6000 | 8x RTX 3090 | single A6000
    #   columns: 1.1B     | 7B          | 13B          | 70B
    # The result is written to end-to-end.pdf.
    #
    # NOTE: marimo treats names starting with "_" as private to the cell, so
    # every helper and loop variable introduced here is underscore-prefixed to
    # avoid creating new cell outputs.
    fig, ax = plt.subplots(
        figsize=(14, 14 / 2.2), ncols=4, nrows=3, layout="constrained"
    )

    # Line colours as RGB tuples in [0, 1].
    c_1 = (139 / 255, 0 / 255, 0 / 255)  # 1F1B (the *_gpipe series)
    c_2 = (0, 0, 0)  # TP, and PEFT in the single-GPU row
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # FSDP
    c_4 = (230 / 255, 109 / 255, 104 / 255)  # mLoRA and the annotation text

    # Tokens processed per task. Times are total_token / throughput / 3600,
    # i.e. hours assuming the throughput lists are in tokens per second
    # (NOTE(review): units are not recorded alongside the data — confirm).
    total_token = 7473 * 512 * 10

    def _hours(throughputs):
        """Return, for each throughput, the hours needed for ``total_token`` tokens."""
        return [total_token / tp / 60 / 60 for tp in throughputs]

    def _set_ticks(axis, positions, labels):
        """Place x-ticks at ``positions`` labelled with ``labels``."""
        axis.set_xticks(positions)
        axis.set_xticklabels(labels)

    def _note(axis, fx, fy, text, va="bottom", ha="right", fontsize=12):
        """Write ``text`` at axes-fraction coordinates (fx, fy) in colour c_4."""
        axis.text(
            fx,
            fy,
            text,
            fontsize=fontsize,
            va=va,
            ha=ha,
            transform=axis.transAxes,
            color=c_4,
        )

    # ---- Data: single A6000 (mLoRA = *_m_s, PEFT = *_p_s) -----------------
    b1_tp_m_s = [
        2857.40228,
        3016.124377,
        3043.99588,
        3047.335256,
        3051.551977,
        3051.512532,
        3048.015064,
        3047.108509,
        3048.642661,
        3051.840965,
        3047.57159,
        3047.861865,
    ]
    b1_tp_p_s = [
        2857.389389,
        2842,
        2851,
        2847,
        2853,
        2841,
        2843,
        2851,
        2850,
        2851.3,
        2849,
        2852,
    ]

    b7_tp_m_s = [702.5522131, 716.7349242, 722.2261427, 725.5761517, 727.0030057]
    b7_tp_p_s = [702.6845352, 699.1034186, 701.3211972, 700.1283237, 700.1098232]

    b13_tp_m_s = [398.8387303, 403.5820717, 405.8601994]
    b13_tp_p_s = [398.7009553, 398.4052117, 399.1230098]

    # ---- Data: 4x A6000 ---------------------------------------------------
    b1_throughput_mLoRA = [
        4600.45,
        8664.91,
        10118.36,
        10184.44,
        10119,
        11157,
        11157,
        11530,
        11580,
        11580,
        11600,
        11600,
    ]
    b1_throughput_tp = [
        5752.14,
        5749.34,
        5756.78,
        5758.32,
        5753.14,
        5753.34,
        5756.78,
        5756.32,
        5758.14,
        5749.34,
        5753.78,
        5754.32,
    ]
    b1_throughput_fsdp = [
        6151.91,
        6141.73,
        6161.23,
        6153.93,
        6153.91,
        6146.73,
        6161.23,
        6151.93,
        6157.91,
        6143.73,
        6161.23,
        6153.93,
    ]
    b1_throughput_gpipe = [
        4599.87,
        4610.19,
        4592.17,
        4601.18,
        4598.87,
        4600.19,
        4593.17,
        4601.18,
        4599.87,
        4610.19,
        4592.17,
        4603.18,
    ]

    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]

    b13_throughput_mLoRA = [723.21, 1280.54]
    b13_throughput_tp = [875, 877]
    # FSDP is out of memory here (annotated on panel (c)). Kept empty, like the
    # other OOM placeholders, so converting it to hours can never divide by 0.
    b13_throughput_fsdp = []
    b13_throughput_gpipe = [723.21, 719.21]

    b70_throughput_mLoRA = [234, 291, 320, 318]
    b70_throughput_gpipe = [234.34, 234.32, 234.38, 234]

    # ---- Data: 8x RTX 3090 -------------------------------------------------
    # NOTE: these variables are named *_4090_* but the panels are titled 3090×8.
    b1_4090_mlora = [
        319.61,
        580.34,
        663.12,
        799.69,
        800.96,
        813.64,
        812.92,
        814.93,
    ]
    b1_4090_tp = [35.17, 35.33, 35.32, 35.32, 35.38, 35.34, 35.33, 35.34]
    b1_4090_fsdp = [62.79, 62.31, 60.03, 61.53, 62.79, 60.37, 61.23, 63.11]
    b1_4090_gpipe = [
        318.99,
        319.16,
        319.61,
        319.42,
        319.23,
        319.61,
        319.80,
        319.96,
    ]

    b7_4090_mlora = [576.79, 671.39, 702.91, 715.85]
    b7_4090_gpipe = [578.97, 581.46, 573.70, 580.84]
    b7_4090_fsdp = []
    b7_4090_tp = [11.93, 12.09, 11.86, 12.10]

    b13_4090_mlora = [524.57, 694.34]
    b13_4090_gpipe = [528.47, 522.06]
    b13_4090_fsdp = []
    b13_4090_tp = []

    # ---- Row 0: 4x A6000 --------------------------------------------------
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]
    m_avg_time = _hours(b1_throughput_mLoRA)
    tp_avg_time = _hours(b1_throughput_tp)
    fsdp_avg_time = _hours(b1_throughput_fsdp)
    g_avg_time = _hours(b1_throughput_gpipe)

    ax[0][0].plot(x, g_avg_time, color=c_1, marker="v")
    ax[0][0].plot(x, fsdp_avg_time, color=c_3, marker="^")
    ax[0][0].plot(x, tp_avg_time, color=c_2, marker="o")
    ax[0][0].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[0][0], ticks, ticks_label)

    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _hours(b7_throughput_mLoRA)
    tp_avg_time = _hours(b7_throughput_tp)
    fsdp_avg_time = _hours(b7_throughput_fsdp)
    g_avg_time = _hours(b7_throughput_gpipe)

    # The 1F1B / FSDP / TP legend entries come from the labels on this panel.
    ax[0][1].plot(x, g_avg_time, color=c_1, marker="v", label="1F1B")
    ax[0][1].plot(x, fsdp_avg_time, color=c_3, marker="^", label="FSDP")
    ax[0][1].plot(x, tp_avg_time, color=c_2, marker="o", label="TP")
    ax[0][1].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[0][1], ticks, ticks_label)

    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    m_avg_time = _hours(b13_throughput_mLoRA)
    tp_avg_time = _hours(b13_throughput_tp)
    g_avg_time = _hours(b13_throughput_gpipe)

    ax[0][2].plot(x, g_avg_time, color=c_1, marker="v")
    ax[0][2].plot(x, tp_avg_time, color=c_2, marker="o")
    ax[0][2].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[0][2], ticks, ticks_label)

    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _hours(b70_throughput_mLoRA)
    g_avg_time = _hours(b70_throughput_gpipe)

    ax[0][3].plot(x, g_avg_time, color=c_1, marker="v")
    ax[0][3].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[0][3], ticks, ticks_label)

    # ---- Row 1: 8x RTX 3090 -----------------------------------------------
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    ticks = [2, 4, 6, 8]
    ticks_label = ["2", "4", "6", "8"]
    m_avg_time = _hours(b1_4090_mlora)
    tp_avg_time = _hours(b1_4090_tp)
    fsdp_avg_time = _hours(b1_4090_fsdp)
    g_avg_time = _hours(b1_4090_gpipe)

    ax[1][0].plot(x, g_avg_time, color=c_1, marker="v")
    ax[1][0].plot(x, fsdp_avg_time, color=c_3, marker="^")
    ax[1][0].plot(x, tp_avg_time, color=c_2, marker="o")
    ax[1][0].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[1][0], ticks, ticks_label)

    # 7B: TP and FSDP are deliberately not drawn; the panel carries a text
    # note instead (FSDP OOM, TP "about one month").
    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _hours(b7_4090_mlora)
    g_avg_time = _hours(b7_4090_gpipe)

    ax[1][1].plot(x, g_avg_time, color=c_1, marker="v")
    ax[1][1].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[1][1], ticks, ticks_label)

    # 13B: TP and FSDP are OOM (empty lists), so only 1F1B and mLoRA are drawn.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    m_avg_time = _hours(b13_4090_mlora)
    tp_avg_time = _hours(b13_4090_tp)
    fsdp_avg_time = _hours(b13_4090_fsdp)
    g_avg_time = _hours(b13_4090_gpipe)

    ax[1][2].plot(x, g_avg_time, color=c_1, marker="v")
    ax[1][2].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[1][2], ticks, ticks_label)

    # 70B: every method is OOM; only the ticks (and a text note) are drawn.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    _set_ticks(ax[1][3], ticks, ticks_label)

    # ---- Row 2: single A6000 ----------------------------------------------
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]
    m_avg_time = _hours(b1_tp_m_s)
    p_avg_time = _hours(b1_tp_p_s)
    # The PEFT / mLoRA legend entries come from the labels on this panel.
    ax[2][0].plot(x, p_avg_time, color=c_2, marker="^", label="PEFT")
    ax[2][0].plot(x, m_avg_time, color=c_4, marker="*", label="mLoRA")
    _set_ticks(ax[2][0], ticks, ticks_label)

    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]
    m_avg_time = _hours(b7_tp_m_s)
    p_avg_time = _hours(b7_tp_p_s)
    ax[2][1].plot(x, p_avg_time, color=c_2, marker="^")
    ax[2][1].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[2][1], ticks, ticks_label)

    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]
    m_avg_time = _hours(b13_tp_m_s)
    p_avg_time = _hours(b13_tp_p_s)
    ax[2][2].plot(x, p_avg_time, color=c_2, marker="^")
    ax[2][2].plot(x, m_avg_time, color=c_4, marker="*")
    _set_ticks(ax[2][2], ticks, ticks_label)

    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    _set_ticks(ax[2][3], ticks, ticks_label)

    # ---- Axis limits, titles and labels -----------------------------------
    _ylims = (
        ((0, 3), (0, 10), (0, 20), (0, 60)),
        ((0, 400), (0, 40), (0, 40), (0, 40)),
        ((3, 4), (13, 16), (25, 27), (25, 27)),
    )
    for _row_axes, _row_lims in zip(ax, _ylims):
        for _axis, (_lo, _hi) in zip(_row_axes, _row_lims):
            _axis.set_ylim(_lo, _hi)

    # Fixed x-ranges around the integer data points of these panels.
    ax[0][2].set_xlim(0.8, 3 - 0.8)
    ax[1][0].set_xlim(0.5, 9 - 0.5)
    ax[1][1].set_xlim(0.5, 5 - 0.5)
    ax[1][2].set_xlim(0.8, 3 - 0.8)
    ax[2][2].set_xlim(0.8, 4 - 0.8)

    _titles = (
        (
            "(a) 1.1B A6000×4",
            "(b) 7B A6000×4",
            "(c) 13B A6000×4",
            "(d) 70B A6000×4",
        ),
        (
            "(e) 1.1B 3090×8",
            "(f) 7B 3090×8",
            "(g) 13B 3090×8",
            "(h) 70B 3090×8",
        ),
        (
            "(i) 1.1B A6000",
            "(j) 7B A6000",
            "(k) 13B A6000",
            "(l) 70B A6000",
        ),
    )
    for _row_axes, _row_titles in zip(ax, _titles):
        for _axis, _title in zip(_row_axes, _row_titles):
            _axis.set_title(_title, fontsize=16)
        _row_axes[0].set_ylabel("Average task\ncompletion time (h)")

    # ---- Notes for configurations that could not be (fully) plotted -------
    _note(ax[0][2], 0.9, 0.8, "FSDP : OOM")
    _note(ax[0][3], 0.9, 0.1, "FSDP : OOM\nTP : OOM")
    _note(ax[1][1], 0.9, 0.7, "FSDP : OOM\nTP : about one month")
    _note(ax[1][2], 0.9, 0.7, "FSDP : OOM\nTP : OOM")
    _note(
        ax[1][3],
        0.5,
        0.5,
        "FSDP : OOM\nTP : OOM\n1F1B : OOM\nmLoRA : OOM",
        va="center",
        ha="center",
    )
    _note(
        ax[2][3],
        0.5,
        0.5,
        "PEFT : OOM\nmLoRA : OOM",
        va="center",
        ha="center",
    )

    # Figure-level legend built from every labelled line: 1F1B/FSDP/TP from
    # panel (b) and PEFT/mLoRA from panel (i).
    fig.legend(
        ncol=5,
        bbox_to_anchor=(0.75, 1.05),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    # Arrow (data coordinates) and "enable BatchLoRA" label in panel (a).
    ax[0][0].arrow(5, 0.5, 0.5, 0.3, width=0.01, head_width=0.1)
    _note(ax[0][0], 0.38, 0.07, "enable BatchLoRA", fontsize=10)

    fig.supxlabel(
        "Number of simultaneously trained LoRA adapters",
        fontsize=16,
        y=-0.03,
        ha="center",
        va="bottom",
    )

    # Save this cell's figure explicitly rather than pyplot's "current" figure.
    fig.savefig("end-to-end.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_4090_fsdp,
        b13_4090_gpipe,
        b13_4090_mlora,
        b13_4090_tp,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_4090_fsdp,
        b1_4090_gpipe,
        b1_4090_mlora,
        b1_4090_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b1_tp_m_s,
        b1_tp_p_s,
        b70_throughput_gpipe,
        b70_throughput_mLoRA,
        b7_4090_fsdp,
        b7_4090_gpipe,
        b7_4090_mlora,
        b7_4090_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        b7_tp_m_s,
        b7_tp_p_s,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp_avg_time,
        g_avg_time,
        m_avg_time,
        p_avg_time,
        ticks,
        ticks_label,
        total_token,
        tp_avg_time,
        x,
    )


@app.cell
def __():
    # Empty trailing cell: it references and defines nothing.
    return


# When this file is executed directly as a script, run the notebook's cells.
if __name__ == "__main__":
    app.run()