import marimo

# Version string recorded by marimo when it generated this notebook file.
__generated_with = "0.7.0"
# The notebook application object; every @app.cell-decorated function below
# is registered with it as one notebook cell.
app = marimo.App(width="medium")


@app.cell
def __():
    # Notebook-wide imports. They are returned so that other cells can list
    # them as parameters.
    # NOTE(review): `random` is not referenced by any cell visible in this
    # file; it is kept because it is part of this cell's exported names.
    import random

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap
    return LinearSegmentedColormap, np, plt, random


@app.cell
def __(plt):
    # Set the global matplotlib font defaults (serif family, 16 pt) in a
    # single update of the shared rcParams mapping.
    # NOTE(review): the plotting cell does not take anything from this cell,
    # so marimo's dependency graph does not force this cell to run first.
    # Confirm the intended execution order.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return


@app.cell(hide_code=True)
def __(LinearSegmentedColormap, np, plt):
    # Draw one "bubble ratio" heat map per GPU count in D_values, side by
    # side, with a single colorbar shared by all panels.
    D_values = [4, 8]
    # squeeze=False keeps `axs` 2-D even for a single panel, so flatten()
    # always yields a 1-D array of Axes indexable by panel position.
    fig, axs = plt.subplots(
        1, len(D_values), figsize=(7, 3.7), dpi=400, squeeze=False
    )
    axs = axs.flatten()

    def smooth_curve(points, factor=0.8):
        """Return an exponential moving average of ``points``.

        The first output equals the first input; every later output is
        ``previous_output * factor + point * (1 - factor)``.
        NOTE(review): not used by the plot in this cell.
        """
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                smoothed_points.append(point)
        return smoothed_points

    # RGB palette entries (0-1 floats); not referenced by the plot below.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Colormap from white (low ratio) through light red and red to dark red
    # (high ratio). It is identical for every panel, so build it once.
    colors = [
        (1, 1, 1),  # white
        (240 / 255, 175 / 255, 175 / 255),  # light red
        (230 / 255, 109 / 255, 104 / 255),  # red
        (100 / 255, 0 / 255, 0 / 255),  # dark red
    ]
    cmap = LinearSegmentedColormap.from_list("custom_red_black", colors, N=256)

    for idx, D in enumerate(D_values):
        # matrix[N - 1, L - 1] = max(0, (D + N - 1 - L*N) / (D + N - 1)).
        # Rows are N = 1..D micro-batches; columns are L = 1..D adapters.
        # Since L, N >= 1 the numerator is at most D - 1 < D + N - 1, so
        # every entry lies in [0, 1).
        matrix = np.zeros((D, D))
        for N in range(1, D + 1):
            for L in range(1, D + 1):
                value = (D + N - 1 - L * N) / (D + N - 1)
                matrix[N - 1, L - 1] = max(0, value)

        # Pin the colour scale to the ratio's range [0, 1]. Without this,
        # each image auto-scales to its own maximum, and the single shared
        # colorbar (built from the last image) would not match the other
        # panels. `cax` is the AxesImage returned by imshow.
        cax = axs[idx].imshow(
            matrix, cmap=cmap, origin="lower", vmin=0.0, vmax=1.0
        )

        # Put a tick every `step` cells, starting at cell `start`. Labels
        # are 1-based counts. Tick positions are cell centres 0..D-1 (the
        # image spans -0.5..D-0.5), so the range stops before D.
        step = 2 ** (idx - 1) if idx > 0 else 1
        start = step - 1
        _tick_pos = range(start, D, step)
        _tick_labels = range(start + 1, D + 1, step)
        axs[idx].set_xticks(_tick_pos, _tick_labels)
        axs[idx].set_yticks(_tick_pos, _tick_labels)

        axs[idx].set_title(f"Number of GPUs = {D}", fontsize=16)
        axs[idx].set_ylabel("Number of micro-batches", fontsize=16)

        # Write each cell's value on the first panel only. Zero cells stay
        # blank; values >= 0.6 get white text for contrast on dark cells.
        if idx < 1:
            for i in range(D):
                for j in range(D):
                    color = "white" if matrix[i, j] >= 0.6 else "black"
                    word = f"{matrix[i, j]:.2f}" if matrix[i, j] != 0 else ""
                    text = axs[idx].text(
                        j,
                        i,
                        word,
                        ha="center",
                        va="center",
                        color=color,
                        fontsize=14,
                    )

    # One x-label centred under all panels. This replaces the old approach
    # of padding the first panel's label with spaces, which depended on
    # font metrics.
    fig.supxlabel(
        "Number of simultaneously trained LoRA adapters", fontsize=16
    )

    fig.tight_layout()
    # The colorbar takes space from every panel. Because all images share
    # vmin/vmax above, it is valid for all of them.
    colorbar = fig.colorbar(cax, ax=axs, location="right", pad=0.01)
    colorbar.set_label("Bubble ratio", fontsize=16)
    plt.show()
    # plt.savefig("bubble.pdf", bbox_inches="tight", dpi=1000)
    return (
        D,
        D_values,
        L,
        N,
        axs,
        c_1,
        c_2,
        c_3,
        c_4,
        cax,
        cmap,
        color,
        colorbar,
        colors,
        fig,
        i,
        idx,
        j,
        matrix,
        smooth_curve,
        start,
        step,
        text,
        value,
        word,
    )


# When this file is executed directly (not imported), run the notebook app.
if __name__ == "__main__":
    app.run()