import marimo

__generated_with = "0.7.0"
app = marimo.App(width="medium")


@app.cell
def __():
    import numpy as np
    import matplotlib.pyplot as plt
    import random
    from matplotlib.colors import LinearSegmentedColormap
    return LinearSegmentedColormap, np, plt, random


@app.cell
def __(plt):
    # Global font settings used by the figures in this notebook.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 16
    return


@app.cell(hide_code=True)
def __(LinearSegmentedColormap, np, plt):
    # Side-by-side heatmaps of the bubble ratio, one panel per GPU count D,
    # all sharing a single colorbar.

    # GPU counts to plot; one panel is drawn for each entry.
    D_values = [4, 8]

    def smooth_curve(points, factor=0.8):
        """Return an exponential moving average of *points*.

        The first output equals the first input; every later output is
        ``previous_output * factor + point * (1 - factor)``.
        Not used by the plot below.
        """
        smoothed_points = []
        for point in points:
            if smoothed_points:
                smoothed_points.append(
                    smoothed_points[-1] * factor + point * (1 - factor)
                )
            else:
                smoothed_points.append(point)
        return smoothed_points

    def bubble_ratio_matrix(D):
        """Return the D x D bubble-ratio matrix for D GPUs.

        Row ``N - 1`` corresponds to N micro-batches and column ``L - 1`` to
        L simultaneously trained LoRA adapters (N, L in 1..D). Each entry is
        ``max(0, (D + N - 1 - L * N) / (D + N - 1))``.
        """
        N = np.arange(1, D + 1)[:, None]
        L = np.arange(1, D + 1)[None, :]
        return np.maximum(0, (D + N - 1 - L * N) / (D + N - 1))

    matrices = [bubble_ratio_matrix(D) for D in D_values]
    # imshow normalises every image independently by default, which would
    # make the single shared colorbar wrong for all but the last panel.
    # Use one colour scale for every panel instead.
    vmax = max(float(m.max()) for m in matrices)

    colors = [
        (1, 1, 1),  # white
        (240 / 255, 175 / 255, 175 / 255),  # light red
        (230 / 255, 109 / 255, 104 / 255),  # red
        (100 / 255, 0 / 255, 0 / 255),  # dark red
    ]
    cmap = LinearSegmentedColormap.from_list("custom_red_black", colors, N=256)

    # squeeze=False keeps axs 2-D even for a single panel, so ravel() always
    # yields a flat array of Axes.
    fig, axs = plt.subplots(
        1, len(D_values), figsize=(7, 3.7), dpi=400, squeeze=False
    )
    axs = axs.ravel()

    im = None
    for idx, (D, matrix) in enumerate(zip(D_values, matrices)):
        ax = axs[idx]

        # One tick per cell, labelled 1..D.
        ticks = list(range(D))
        tick_labels = [str(k + 1) for k in ticks]
        ax.set_xticks(ticks, tick_labels)
        ax.set_yticks(ticks, tick_labels)

        im = ax.imshow(matrix, cmap=cmap, origin="lower", vmin=0, vmax=vmax)
        ax.set_title(f"Number of GPUs = {D}", fontsize=16)
        ax.set_ylabel("Number of micro-batches", fontsize=16)

        # Only the first panel gets per-cell value labels; the larger grids
        # would be too crowded to read.
        if idx == 0:
            for i in range(D):
                for j in range(D):
                    color = "white" if matrix[i, j] >= 0.6 else "black"
                    word = f"{matrix[i, j]:.2f}" if matrix[i, j] != 0 else ""
                    ax.text(
                        j,
                        i,
                        word,
                        ha="center",
                        va="center",
                        color=color,
                        fontsize=14,
                    )

    axs[0].set_xlabel(
        " Number of simultaneously trained LoRA adapters",
        fontsize=16,
    )

    plt.tight_layout()
    # All panels share the same vmin/vmax, so any image can back the colorbar.
    colorbar = fig.colorbar(im, ax=axs, location="right", pad=0.01)
    colorbar.set_label("Bubble ratio", fontsize=16)
    # To export, save BEFORE plt.show() (afterwards the figure may be empty):
    # plt.savefig("bubble.pdf", bbox_inches="tight", dpi=1000)
    plt.show()
    return (
        D_values,
        axs,
        bubble_ratio_matrix,
        cmap,
        colorbar,
        colors,
        fig,
        im,
        matrices,
        smooth_curve,
        vmax,
    )


if __name__ == "__main__":
    app.run()