# File listing metadata: 535 lines, 14 KiB, Python.
import marimo
# marimo version string recorded in this notebook file.
__generated_with = "0.9.17"

# Application object; every cell below is registered on it via @app.cell.
app = marimo.App(width="medium")


@app.cell
def __():
    # Import the plotting stack and NumPy once; the three names are returned
    # so that the cells below can receive them as parameters.
    import matplotlib
    import numpy as np
    from matplotlib import pyplot as plt

    return matplotlib, np, plt


@app.cell
def __(matplotlib, plt):
    # Global Matplotlib text settings: LaTeX rendering switched off,
    # Times New Roman as the font family, 16 pt as the base font size.
    matplotlib.rcParams["text.usetex"] = False
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return


@app.cell
def __(np, plt):
    # Draws a 3x4 grid of line charts from the hard-coded measurements below
    # and writes the figure to "end-to-end.pdf".
    #
    # Every curve is computed as total_token / value / 60 / 60 for the values
    # of one measurement list; the y-axis label reports the result as
    # "Average task completion time (h)".
    #
    # Names starting with "_" are helpers private to this cell. All other
    # names are handed back through the return tuple at the bottom, so the
    # reassignment order of x / ticks / *_avg_time is kept on purpose.

    # Placeholder value: x is reassigned before any plot call reads it.
    x = np.arange(4)

    fig, ax = plt.subplots(
        nrows=3,
        ncols=4,
        figsize=(14, 14 / 2.2),
        layout="constrained",
    )

    # Line colours, given as RGB fractions.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Numerator of every completion-time value.
    # NOTE(review): presumably samples x sequence length x epochs -- confirm.
    total_token = 7473 * 512 * 10

    # ------------------------------------------------------------------
    # A6000, single GPU. The *_m_s lists are drawn with the "mLoRA" style
    # and the *_p_s lists with the "PEFT" style further down.
    # ------------------------------------------------------------------
    b1_tp_m_s = [
        2857.40228, 3016.124377, 3043.99588, 3047.335256,
        3051.551977, 3051.512532, 3048.015064, 3047.108509,
        3048.642661, 3051.840965, 3047.57159, 3047.861865,
    ]
    b1_tp_p_s = [
        2857.389389, 2842, 2851, 2847,
        2853, 2841, 2843, 2851,
        2850, 2851.3, 2849, 2852,
    ]

    b7_tp_m_s = [702.5522131, 716.7349242, 722.2261427, 725.5761517, 727.0030057]
    b7_tp_p_s = [702.6845352, 699.1034186, 701.3211972, 700.1283237, 700.1098232]

    b13_tp_m_s = [398.8387303, 403.5820717, 405.8601994]
    b13_tp_p_s = [398.7009553, 398.4052117, 399.1230098]

    # ------------------------------------------------------------------
    # A6000, 4 GPUs.
    # ------------------------------------------------------------------
    b1_throughput_mLoRA = [
        4600.45, 8664.91, 10118.36, 10184.44,
        10119, 11157, 11157, 11530,
        11580, 11580, 11600, 11600,
    ]
    b1_throughput_tp = [
        5752.14, 5749.34, 5756.78, 5758.32,
        5753.14, 5753.34, 5756.78, 5756.32,
        5758.14, 5749.34, 5753.78, 5754.32,
    ]
    b1_throughput_fsdp = [
        6151.91, 6141.73, 6161.23, 6153.93,
        6153.91, 6146.73, 6161.23, 6151.93,
        6157.91, 6143.73, 6161.23, 6153.93,
    ]
    b1_throughput_gpipe = [
        4599.87, 4610.19, 4592.17, 4601.18,
        4598.87, 4600.19, 4593.17, 4601.18,
        4599.87, 4610.19, 4592.17, 4603.18,
    ]

    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]

    b13_throughput_mLoRA = [723.21, 1280.54]
    b13_throughput_tp = [875, 877]
    # Zero placeholders ("for the error" in the original note); this list is
    # not passed to any plot call in this cell.
    b13_throughput_fsdp = [0, 0, 0, 0]
    b13_throughput_gpipe = [723.21, 719.21]

    b70_throughput_mLoRA = [234, 291, 320, 318]
    b70_throughput_gpipe = [234.34, 234.32, 234.38, 234]

    # ------------------------------------------------------------------
    # 3090, 8 GPUs.
    # NOTE(review): the variable names say "4090" while this comment and the
    # panel titles say "3090" -- confirm which one is correct.
    # ------------------------------------------------------------------
    b1_4090_mlora = [
        319.61, 580.34, 663.12, 799.69,
        800.96, 813.64, 812.92, 814.93,
    ]
    b1_4090_tp = [35.17, 35.33, 35.32, 35.32, 35.38, 35.34, 35.33, 35.34]
    b1_4090_fsdp = [62.79, 62.31, 60.03, 61.53, 62.79, 60.37, 61.23, 63.11]
    b1_4090_gpipe = [
        318.99, 319.16, 319.61, 319.42,
        319.23, 319.61, 319.80, 319.96,
    ]

    b7_4090_mlora = [576.79, 671.39, 702.91, 715.85]
    b7_4090_gpipe = [578.97, 581.46, 573.70, 580.84]
    b7_4090_fsdp = []
    b7_4090_tp = [11.93, 12.09, 11.86, 12.10]

    b13_4090_mlora = [524.57, 694.34]
    b13_4090_gpipe = [528.47, 522.06]
    b13_4090_fsdp = []
    b13_4090_tp = []

    # ------------------------------------------------------------------
    # Cell-private helpers.
    # ------------------------------------------------------------------
    def _to_hours(rates):
        # Shared arithmetic for every curve: total_token / rate / 60 / 60.
        return [total_token / rate / 60 / 60 for rate in rates]

    def _draw(panel, xs, curves, positions, labels):
        # Plot each (values, line kwargs) pair in the given order, then set
        # the x tick positions and their labels on that panel.
        for values, look in curves:
            panel.plot(xs, values, **look)
        panel.set_xticks(positions)
        panel.set_xticklabels(labels)

    # Colour/marker combination of each curve family. The legend labels are
    # attached only once per family (panels (b) and (i)).
    _look_1f1b = {"color": c_1, "marker": "v"}
    _look_fsdp = {"color": c_3, "marker": "^"}
    _look_tp = {"color": c_2, "marker": "o"}
    _look_mlora = {"color": c_4, "marker": "*"}
    _look_peft = {"color": c_2, "marker": "^"}

    # ------------------------------------------------------------------
    # Row 0: A6000 x4.
    # ------------------------------------------------------------------
    x = list(range(1, 13))
    ticks = list(range(2, 13, 2))
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b1_throughput_mLoRA)
    tp_avg_time = _to_hours(b1_throughput_tp)
    fsdp_avg_time = _to_hours(b1_throughput_fsdp)
    g_avg_time = _to_hours(b1_throughput_gpipe)
    _draw(
        ax[0, 0],
        x,
        [
            (g_avg_time, _look_1f1b),
            (fsdp_avg_time, _look_fsdp),
            (tp_avg_time, _look_tp),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 5))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b7_throughput_mLoRA)
    tp_avg_time = _to_hours(b7_throughput_tp)
    fsdp_avg_time = _to_hours(b7_throughput_fsdp)
    g_avg_time = _to_hours(b7_throughput_gpipe)
    _draw(
        ax[0, 1],
        x,
        [
            (g_avg_time, {**_look_1f1b, "label": "1F1B"}),
            (fsdp_avg_time, {**_look_fsdp, "label": "FSDP"}),
            (tp_avg_time, {**_look_tp, "label": "TP"}),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 3))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b13_throughput_mLoRA)
    tp_avg_time = _to_hours(b13_throughput_tp)
    g_avg_time = _to_hours(b13_throughput_gpipe)
    _draw(
        ax[0, 2],
        x,
        [
            (g_avg_time, _look_1f1b),
            (tp_avg_time, _look_tp),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 5))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b70_throughput_mLoRA)
    g_avg_time = _to_hours(b70_throughput_gpipe)
    _draw(
        ax[0, 3],
        x,
        [
            (g_avg_time, _look_1f1b),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    # ------------------------------------------------------------------
    # Row 1: 3090 x8.
    # ------------------------------------------------------------------
    x = list(range(1, 9))
    ticks = list(range(2, 9, 2))
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b1_4090_mlora)
    tp_avg_time = _to_hours(b1_4090_tp)
    fsdp_avg_time = _to_hours(b1_4090_fsdp)
    g_avg_time = _to_hours(b1_4090_gpipe)
    _draw(
        ax[1, 0],
        x,
        [
            (g_avg_time, _look_1f1b),
            (fsdp_avg_time, _look_fsdp),
            (tp_avg_time, _look_tp),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 5))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b7_4090_mlora)
    tp_avg_time = _to_hours(b7_4090_tp)
    fsdp_avg_time = _to_hours(b7_4090_fsdp)
    g_avg_time = _to_hours(b7_4090_gpipe)
    # The FSDP and TP series are computed but not drawn on this panel; a
    # text note is placed on it further down instead.
    _draw(
        ax[1, 1],
        x,
        [
            (g_avg_time, _look_1f1b),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 3))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b13_4090_mlora)
    tp_avg_time = _to_hours(b13_4090_tp)
    fsdp_avg_time = _to_hours(b13_4090_fsdp)
    g_avg_time = _to_hours(b13_4090_gpipe)
    # As above: only the 1F1B and mLoRA curves are drawn here.
    _draw(
        ax[1, 2],
        x,
        [
            (g_avg_time, _look_1f1b),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    # Panel (h) has no curves; only its ticks are set.
    x = list(range(1, 3))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    _draw(ax[1, 3], x, [], ticks, ticks_label)

    # ------------------------------------------------------------------
    # Row 2: A6000, single GPU.
    # ------------------------------------------------------------------
    x = list(range(1, 13))
    ticks = list(range(2, 13, 2))
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b1_tp_m_s)
    p_avg_time = _to_hours(b1_tp_p_s)
    _draw(
        ax[2, 0],
        x,
        [
            (p_avg_time, {**_look_peft, "label": "PEFT"}),
            (m_avg_time, {**_look_mlora, "label": "mLoRA"}),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 6))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b7_tp_m_s)
    p_avg_time = _to_hours(b7_tp_p_s)
    _draw(
        ax[2, 1],
        x,
        [
            (p_avg_time, _look_peft),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    x = list(range(1, 4))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    m_avg_time = _to_hours(b13_tp_m_s)
    p_avg_time = _to_hours(b13_tp_p_s)
    _draw(
        ax[2, 2],
        x,
        [
            (p_avg_time, _look_peft),
            (m_avg_time, _look_mlora),
        ],
        ticks,
        ticks_label,
    )

    # Panel (l) has no curves; only its ticks are set.
    x = list(range(1, 3))
    ticks = list(x)
    ticks_label = list(map(str, ticks))
    _draw(ax[2, 3], x, [], ticks, ticks_label)

    # ------------------------------------------------------------------
    # Axis ranges, one (low, high) pair per panel, laid out like the grid.
    # ------------------------------------------------------------------
    _y_windows = (
        ((0, 3), (0, 10), (0, 20), (0, 60)),
        ((0, 400), (0, 40), (0, 40), (0, 40)),
        ((3, 4), (13, 16), (25, 27), (25, 27)),
    )
    for _row_axes, _row_windows in zip(ax, _y_windows):
        for _panel, (_low, _high) in zip(_row_axes, _row_windows):
            _panel.set_ylim(_low, _high)

    # Explicit x ranges for the panels that set one.
    ax[0, 2].set_xlim(0.8, 3 - 0.8)

    ax[1, 0].set_xlim(0.5, 9 - 0.5)
    ax[1, 1].set_xlim(0.5, 5 - 0.5)
    ax[1, 2].set_xlim(0.8, 3 - 0.8)

    ax[2, 2].set_xlim(0.8, 4 - 0.8)

    # ------------------------------------------------------------------
    # Panel titles in row-major order, and the shared y-axis label on the
    # first column.
    # ------------------------------------------------------------------
    _titles = (
        "(a) 1.1B A6000×4",
        "(b) 7B A6000×4",
        "(c) 13B A6000×4",
        "(d) 70B A6000×4",
        "(e) 1.1B 3090×8",
        "(f) 7B 3090×8",
        "(g) 13B 3090×8",
        "(h) 70B 3090×8",
        "(i) 1.1B A6000",
        "(j) 7B A6000",
        "(k) 13B A6000",
        "(l) 70B A6000",
    )
    for _panel, _title in zip(ax.flat, _titles):
        _panel.set_title(_title, fontsize=16)

    for _panel in ax[:, 0]:
        _panel.set_ylabel("Average task\ncompletion time (h)")

    # ------------------------------------------------------------------
    # Text notes placed in axes-fraction coordinates:
    # (panel, x, y, message, vertical alignment, horizontal alignment).
    # ------------------------------------------------------------------
    _notes = (
        (ax[0, 2], 0.9, 0.8, "FSDP : OOM", "bottom", "right"),
        (ax[0, 3], 0.9, 0.1, "FSDP : OOM\nTP : OOM", "bottom", "right"),
        (
            ax[1, 1],
            0.9,
            0.7,
            "FSDP : OOM\nTP : about one month",
            "bottom",
            "right",
        ),
        (ax[1, 2], 0.9, 0.7, "FSDP : OOM\nTP : OOM", "bottom", "right"),
        (
            ax[1, 3],
            0.5,
            0.5,
            "FSDP : OOM\nTP : OOM\n1F1B : OOM\nmLoRA : OOM",
            "center",
            "center",
        ),
        (ax[2, 3], 0.5, 0.5, "PEFT : OOM\nmLoRA : OOM", "center", "center"),
    )
    for _panel, _note_x, _note_y, _message, _valign, _halign in _notes:
        _panel.text(
            _note_x,
            _note_y,
            _message,
            fontsize=12,
            va=_valign,
            ha=_halign,
            transform=_panel.transAxes,
            color=c_4,
        )

    # Figure-level legend built from the labelled curves.
    fig.legend(
        ncol=5,
        bbox_to_anchor=(0.75, 1.05),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    # Arrow and caption on panel (a).
    _first_panel = ax[0, 0]
    _first_panel.arrow(5, 0.5, 0.5, 0.3, width=0.01, head_width=0.1)
    _first_panel.text(
        0.38,
        0.07,
        "enable BatchLoRA",
        fontsize=10,
        va="bottom",
        ha="right",
        transform=_first_panel.transAxes,
        color=c_4,
    )

    fig.supxlabel(
        "Number of simultaneously trained LoRA adapters",
        fontsize=16,
        y=-0.03,
        ha="center",
        va="bottom",
    )

    plt.savefig("end-to-end.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_4090_fsdp,
        b13_4090_gpipe,
        b13_4090_mlora,
        b13_4090_tp,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_4090_fsdp,
        b1_4090_gpipe,
        b1_4090_mlora,
        b1_4090_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b1_tp_m_s,
        b1_tp_p_s,
        b70_throughput_gpipe,
        b70_throughput_mLoRA,
        b7_4090_fsdp,
        b7_4090_gpipe,
        b7_4090_mlora,
        b7_4090_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        b7_tp_m_s,
        b7_tp_p_s,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp_avg_time,
        g_avg_time,
        m_avg_time,
        p_avg_time,
        ticks,
        ticks_label,
        total_token,
        tp_avg_time,
        x,
    )


@app.cell
def __():
    # Empty cell: it runs no code and defines no names.
    return


if __name__ == "__main__":
    # Executed as a script: run the notebook application.
    app.run()