import marimo

# Version string recorded by marimo when this notebook file was generated.
__generated_with = "0.9.17"
# Notebook object; the functions decorated with @app.cell below are its cells.
app = marimo.App(width="medium")


@app.cell
def __():
    # Third-party names made available to the other cells, which receive
    # them as function parameters.
    # NOTE(review): no other cell in this file references LinearRegression;
    # the scikit-learn import looks removable -- confirm before dropping it.
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LinearRegression

    return LinearRegression, np, plt


@app.cell
def __(plt):
    # Global Matplotlib text defaults: Times New Roman family, 16 pt size.
    # NOTE(review): this cell defines nothing, so no other cell depends on it
    # in marimo's dataflow graph; the plot cell only references np and plt.
    # Confirm ordering is acceptable when cells are re-run interactively.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return


@app.cell
def __(np, plt):
    # Three-panel line plot of measured throughput (tokens/s) against LoRA
    # adapter rank, one panel per base model, saved to "adaptability.pdf".

    # NOTE(review): `x` is not used by any plot in this file; it is kept only
    # because this cell exports it.
    x = np.arange(4)

    fig, ax = plt.subplots(
        figsize=(7, 4 / 2), ncols=3, nrows=1, layout="constrained", dpi=300
    )

    # RGB colours scaled to [0, 1]; only c_4 is drawn below.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    # Green channel previously divided by 155 instead of 255 (typo).
    c_2 = (75 / 255, 116 / 255, 178 / 255)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Throughput values (y) at each LoRA rank (x), one pair per model.
    y_0 = [10525.52, 10467.51, 10315.85, 10286.17]
    x_0 = [4, 8, 16, 32]

    y_1 = [2309.40, 2258.73, 2252.43, 2242.80]
    x_1 = [16, 32, 64, 128]

    y_2 = [1245.79, 1244.60, 1224.91, 1207.40]
    x_2 = [16, 32, 64, 128]

    # Per panel: (xs, ys, x tick positions, title, extra plot kwargs).
    # Panels (b) and (c) leave rank 16 without a tick label, as before --
    # presumably so "16" does not collide with "32" on the linear axis.
    # Underscore-prefixed names stay private to this marimo cell.
    _panels = [
        (x_0, y_0, x_0, "(a) TinyLlama-1.1B", {"label": "LoRAPP"}),
        (x_1, y_1, [32, 64, 128], "(b) Llama-2-7B", {}),
        (x_2, y_2, [32, 64, 128], "(c) Llama-2-13B", {}),
    ]
    for _axis, (_xs, _ys, _xticks, _title, _plot_kwargs) in zip(ax, _panels):
        _axis.plot(_xs, _ys, "s-", color=c_4, **_plot_kwargs)
        _axis.set_xlabel("Rank of LoRA adapters", fontsize=14)
        _axis.set_title(_title, fontsize=14)
        _axis.set_xticks(_xticks)
        # Identical y range and "0 / 5k / 10k" ticks on every panel.
        _axis.set_ylim(0, 11000)
        _axis.set_yticks([0, 5000, 10000])
        _axis.set_yticklabels(
            ["0", "5k", "10k"], rotation=90, ha="center", va="center"
        )
        # pad=7 kept from the original; presumably spaces the rotated,
        # centre-aligned y tick labels away from the axis.
        _axis.tick_params(pad=7)

    # Only the leftmost panel carries the shared y-axis label.
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=12)

    # Save through the figure handle instead of pyplot's implicit
    # "current figure".
    fig.savefig("adaptability.pdf", bbox_inches="tight", dpi=1000)
    return ax, c_1, c_2, c_3, c_4, fig, x, x_0, x_1, x_2, y_0, y_1, y_2


if __name__ == "__main__":
    # Executed as a plain script (`python <this file>`): run the notebook's
    # cells via marimo.
    app.run()