# paper_note/mlora/bubble.py
# 2025-03-05 20:38:41 +08:00
#
# 138 lines
# 3.8 KiB
# Python

import marimo
# Marker written by marimo (presumably the marimo version that generated this file).
__generated_with = "0.7.0"
# The notebook application; every function decorated with @app.cell below is one of its cells.
app = marimo.App(width="medium")
@app.cell
def __():
    # Imports shared by the notebook; returning them lets the later cells
    # receive them as parameters.
    import random

    from matplotlib import pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap
    import numpy as np

    return LinearSegmentedColormap, np, plt, random
@app.cell
def __(plt):
    # Global typography applied to every figure drawn afterwards.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return
@app.cell(hide_code=True)
def __(LinearSegmentedColormap, np, plt):
    # Heat maps of the bubble ratio, one panel per GPU count in D_values,
    # all sharing a single colour scale and one colorbar.
    #
    # For D GPUs, N micro-batches (rows, y-axis) and L simultaneously trained
    # LoRA adapters (columns, x-axis), each cell holds
    #     max(0, (D + N - 1 - L * N) / (D + N - 1)),   with N, L in 1..D.

    # GPU counts to plot; one subplot per entry.
    D_values = [4, 8]

    def _bubble_ratio_matrix(depth):
        # Build the depth x depth grid in one broadcast: entry [N-1, L-1] is
        # the clipped ratio above for N micro-batches and L adapters.
        n_micro = np.arange(1, depth + 1)[:, None]
        n_lora = np.arange(1, depth + 1)[None, :]
        ratio = (depth + n_micro - 1 - n_lora * n_micro) / (depth + n_micro - 1)
        return np.maximum(ratio, 0)

    matrices = [_bubble_ratio_matrix(D) for D in D_values]

    # The figure has ONE colorbar, so every panel must share vmin/vmax.
    # Otherwise imshow autoscales each panel to its own data range, and the
    # panels' colours would not correspond to the values on the bar.
    vmin = 0.0
    vmax = max(float(m.max()) for m in matrices)

    # White -> light red -> dark red -> deep red; built once, shared by all panels.
    colors = [
        (1, 1, 1),  # white
        (240 / 255, 175 / 255, 175 / 255),  # light red
        (230 / 255, 109 / 255, 104 / 255),  # dark red
        (100 / 255, 0 / 255, 0 / 255),  # deep red
    ]
    cmap = LinearSegmentedColormap.from_list("custom_red_black", colors, N=256)

    # squeeze=False keeps axs a 2-D array even for a single panel, so
    # flatten() always yields an indexable 1-D array of Axes.
    fig, axs = plt.subplots(
        1, len(D_values), figsize=(7, 3.7), dpi=400, squeeze=False
    )
    axs = axs.flatten()

    for idx, (D, matrix) in enumerate(zip(D_values, matrices)):
        ax = axs[idx]
        cax = ax.imshow(matrix, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)

        # Ticks sit on cell centres 0..D-1 and are labelled 1..D. The first
        # two panels label every cell; from the third panel on the step is
        # 2 ** (idx - 1). Set after imshow so every tick lies inside the image.
        step = max(int(2 ** (idx - 1)), 1)
        start = int(2 ** (idx - 1) - 1)
        ax.set_xticks(range(start, D, step), range(start + 1, D + 1, step))
        ax.set_yticks(range(start, D, step), range(start + 1, D + 1, step))

        # Print each cell's value on the first panel only (presumably the
        # larger grids are too dense to read). Zero cells stay blank, and
        # cells >= 0.6 get white text for contrast against the dark colours.
        if idx < 1:
            for i in range(D):
                for j in range(D):
                    value = matrix[i, j]
                    ax.text(
                        j,
                        i,
                        f"{value:.2f}" if value != 0 else "",
                        ha="center",
                        va="center",
                        color="white" if value >= 0.6 else "black",
                        fontsize=14,
                    )

        ax.set_title(f"Number of GPUs = {D}", fontsize=16)
        ax.set_ylabel("Number of micro-batches", fontsize=16)

    axs[0].set_xlabel(
        " Number of simultaneously trained LoRA adapters",
        fontsize=16,
    )

    # Lay out the panels first; fig.colorbar(ax=axs) then takes space from
    # all panels for one shared bar on the right.
    plt.tight_layout()
    colorbar = fig.colorbar(cax, ax=axs, location="right", pad=0.01)
    colorbar.set_label("Bubble ratio", fontsize=16)
    # Uncomment to also export the figure as a PDF (keep it before plt.show()):
    # fig.savefig("bubble.pdf", bbox_inches="tight", dpi=1000)
    plt.show()
    return (
        D,
        D_values,
        ax,
        axs,
        cax,
        cmap,
        colorbar,
        colors,
        fig,
        i,
        idx,
        j,
        matrices,
        matrix,
        start,
        step,
        value,
        vmax,
        vmin,
    )
# Script entry point: when this file is executed directly rather than
# imported, run the marimo app defined above.
if __name__ == "__main__":
    app.run()