import marimo

# marimo version string recorded for this notebook file (presumably written by
# marimo's serializer — TODO confirm before editing by hand).
__generated_with = "0.9.10"
# Notebook application object; each cell below registers itself with it via the
# `@app.cell` decorator. `width` is passed straight through to marimo.App
# (a display setting — confirm the accepted values in the marimo docs).
app = marimo.App(width="medium")


@app.cell
def __():
    # Line chart: average task completion time (hours) versus the number of
    # LoRA adapters trained at once, 1F1B baseline vs. mLoRA, titled
    # "(a) 70B A6000×4".
    #
    # Helpers below are underscore-prefixed so marimo keeps them local to this
    # cell; the exported names are exactly those in the return tuple.
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    # Global typography: no LaTeX rendering, Times New Roman, 16 pt base size.
    # (plt.rcParams is the same object as matplotlib.rcParams.)
    plt.rcParams.update(
        {
            "text.usetex": False,
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )

    def _rgb255(red, green, blue):
        """Scale 0-255 channel values to the 0-1 floats matplotlib expects."""
        return (red / 255, green / 255, blue / 255)

    c_1 = _rgb255(139, 0, 0)  # dark red: 1F1B line
    c_2 = (0, 0, 0)  # black; exported, not drawn in this cell
    c_3 = _rgb255(191, 191, 191)  # light grey; exported, not drawn in this cell
    c_4 = _rgb255(230, 109, 104)  # light red: mLoRA line and the OOM note

    # Tokens per task; factors presumably samples x sequence length x epochs
    # — TODO confirm.
    total_token = 7473 * 512 * 10

    # Throughput per configuration (1..4 adapters). The seconds->hours
    # conversion below implies tokens per second.
    pp = [234.34, 234.32, 234.38, 234]
    mlora = [234, 291, 320, 318]

    def _hours_to_finish(rate):
        """Time to process `total_token` at `rate`, converted from s to h."""
        return total_token / rate / 60 / 60

    pp_avg_time = [_hours_to_finish(rate) for rate in pp]
    mlora_avg_time = [_hours_to_finish(rate) for rate in mlora]

    x = [1, 2, 3, 4]
    fig, ax = plt.subplots(figsize=(5.5, 2.5), constrained_layout=True)

    # Draw the baseline first, then mLoRA; this order is also the legend order.
    for _series, _color, _marker, _label in (
        (pp_avg_time, c_1, "v", "1F1B"),
        (mlora_avg_time, c_4, "*", "mLoRA"),
    ):
        ax.plot(x, _series, color=_color, marker=_marker, label=_label)

    # Tick labels are derived from the tick positions so the two cannot drift.
    _yticks = [0, 40, 60]
    ax.set_xticks(x)
    ax.set_xticklabels([str(tick) for tick in x])
    ax.set_yticks(_yticks)
    ax.set_yticklabels([str(tick) for tick in _yticks])
    ax.set_ylim(0, _yticks[-1] + 1)  # one unit of headroom above the top tick

    ax.set_title("(a) 70B A6000×4", fontsize=16)
    ax.set_ylabel("Average task\ncompletion time (h)")
    ax.set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # FSDP and TP are reported as out-of-memory (per the annotation text), so
    # they appear as a note in the lower-right corner (axes coordinates)
    # rather than as plotted lines.
    ax.text(
        0.95,
        0.05,
        "FSDP : OOM\nTP : OOM",
        transform=ax.transAxes,
        ha="right",
        va="bottom",
        fontsize=12,
        style="italic",
        color=c_4,
    )

    # Figure-level, frameless, fully transparent two-column legend.
    fig.legend(
        fontsize=14,
        ncol=2,
        framealpha=0.0,
        fancybox=False,
        bbox_to_anchor=(0.75, 0.4),
    )

    # To export the figure as a PDF, uncomment:
    # plt.savefig("end_to_end_large_model.pdf", bbox_inches="tight", dpi=1000, format="pdf")
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        matplotlib,
        mlora,
        mlora_avg_time,
        np,
        plt,
        pp,
        pp_avg_time,
        total_token,
        x,
    )


# Entry point when this file is executed directly as a script: runs the app.
if __name__ == "__main__":
    app.run()