import marimo

# marimo notebook: each function decorated with @app.cell below is one cell.
__generated_with = "0.7.0"
app = marimo.App(width="medium")


@app.cell
def __():
    # Shared imports. Later cells receive these names by declaring them as
    # parameters (e.g. `plt`, `np`, `LinearSegmentedColormap`).
    # NOTE(review): `random` is not used by any cell in this file.
    import random

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap

    return LinearSegmentedColormap, np, plt, random


@app.cell
def __(plt):
    # Default font family and size for figures drawn after this cell runs.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return


@app.cell(hide_code=True)
def __(LinearSegmentedColormap, np, plt):
    # Heat maps of the "Bubble ratio" (colorbar label), one panel per GPU
    # count D. Entry [N - 1, L - 1] of each matrix is
    #     max(0, (D + N - 1 - L * N) / (D + N - 1))
    # where N = number of micro-batches (rows, y axis) and L = number of
    # simultaneously trained LoRA adapters (columns, x axis), both in 1..D.
    # Names prefixed with "_" are marimo cell-local temporaries.

    # GPU counts to plot, one subplot each.
    D_values = [4, 8]

    # squeeze=False keeps `axs` a 2-D array even for a single panel, so
    # flatten() always yields a 1-D array of Axes.
    fig, axs = plt.subplots(
        1, len(D_values), figsize=(7, 3.7), dpi=400, squeeze=False
    )
    axs = axs.flatten()

    # Colour ramp from white (zero ratio) through light/medium red to dark red.
    colors = [
        (1, 1, 1),  # white
        (240 / 255, 175 / 255, 175 / 255),  # light red
        (230 / 255, 109 / 255, 104 / 255),  # medium red
        (100 / 255, 0 / 255, 0 / 255),  # dark red
    ]
    cmap = LinearSegmentedColormap.from_list("custom_red_black", colors, N=256)

    # Build every matrix up front (vectorized over N and L) so that all
    # panels can share one colour scale.
    _matrices = []
    for _D in D_values:
        _n = np.arange(1, _D + 1).reshape(-1, 1)  # micro-batches -> rows
        _l = np.arange(1, _D + 1).reshape(1, -1)  # adapters -> columns
        _matrices.append(
            np.maximum(0, (_D + _n - 1 - _l * _n) / (_D + _n - 1))
        )

    # Shared normalization: without explicit vmin/vmax each imshow scales to
    # its own maximum, so the single colorbar (built from the last panel)
    # would misrepresent the colours of every other panel.
    _vmax = max(float(_m.max()) for _m in _matrices)

    for _idx, (_D, _matrix) in enumerate(zip(D_values, _matrices)):
        _ax = axs[_idx]

        # One tick per cell on the first two panels; each later panel doubles
        # the step. Ticks stop at the last cell (D - 1); labels are 1-based.
        _step = max(int(2 ** (_idx - 1)), 1)
        _start = int(2 ** (_idx - 1) - 1)
        _ticks = list(range(_start, _D, _step))
        _labels = [_t + 1 for _t in _ticks]
        _ax.set_xticks(_ticks, _labels)
        _ax.set_yticks(_ticks, _labels)

        cax = _ax.imshow(
            _matrix, cmap=cmap, origin="lower", vmin=0.0, vmax=_vmax
        )

        # Print cell values on the first panel only (larger panels would be
        # too crowded). Zero cells stay blank; values >= 0.6 use white text
        # for contrast against the darker reds.
        if _idx == 0:
            for (_i, _j), _v in np.ndenumerate(_matrix):
                if _v == 0:
                    continue
                _ax.text(
                    _j,
                    _i,
                    f"{_v:.2f}",
                    ha="center",
                    va="center",
                    color="white" if _v >= 0.6 else "black",
                    fontsize=14,
                )

        _ax.set_title(f"Number of GPUs = {_D}", fontsize=16)
        _ax.set_ylabel("Number of micro-batches", fontsize=16)

    axs[0].set_xlabel(
        " Number of simultaneously trained LoRA adapters",
        fontsize=16,
    )

    plt.tight_layout()
    colorbar = fig.colorbar(cax, ax=axs, location="right", pad=0.01)
    colorbar.set_label("Bubble ratio", fontsize=16)
    plt.show()
    # To export instead: plt.savefig("bubble.pdf", bbox_inches="tight", dpi=1000)
    return D_values, axs, cax, cmap, colorbar, colors, fig


if __name__ == "__main__":
    # Executing this file directly (python <file>.py) runs the notebook app.
    app.run()