paper_note/mlora/end_to_end.py
2025-03-05 20:38:41 +08:00

535 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import marimo
# Version string stored in the notebook file (presumably the marimo release
# that generated it — confirm against marimo's file format docs).
__generated_with = "0.9.17"
# Notebook application object; every @app.cell function below registers one
# cell on it.
app = marimo.App(width="medium")
@app.cell
def __():
    # Import the numeric/plotting stack once and hand the modules to the
    # cells that list them as parameters.
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    return matplotlib, np, plt
@app.cell
def __(matplotlib, plt):
    # Global figure styling: LaTeX text rendering off, 16 pt Times New Roman.
    matplotlib.rcParams["text.usetex"] = False
    for _option, _setting in (
        ("font.family", "Times New Roman"),
        ("font.size", 16),
    ):
        plt.rcParams[_option] = _setting
    return
@app.cell
def __(np, plt):
    # Draws the 3x4 end-to-end comparison figure and saves it as
    # "end-to-end.pdf".  Per the panel titles, rows are hardware setups
    # (A6000x4, 3090x8, single A6000) and columns are model sizes
    # (1.1B, 7B, 13B, 70B).  Every curve is total_token / throughput converted
    # to hours, plotted against the number of concurrently trained adapters.

    # Dead store kept from the original: x is rebound before its first use.
    x = np.arange(4)
    fig, ax = plt.subplots(
        nrows=3, ncols=4, figsize=(14, 14 / 2.2), layout="constrained"
    )

    # Line colours as RGB triples.
    c_1 = (139 / 255, 0 / 255, 0 / 255)  # used for the "1F1B" curves
    c_2 = (0, 0, 0)  # used for the "TP" and "PEFT" curves
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # used for the "FSDP" curves
    c_4 = (230 / 255, 109 / 255, 104 / 255)  # used for "mLoRA" and all notes

    # Work size every throughput is divided into
    # (presumably samples x sequence length x epochs — TODO confirm).
    total_token = 7473 * 512 * 10

    def _hours(_rates):
        # total_token / rate, then two divisions by 60.  The result is plotted
        # on an axis labelled in hours, so rate is presumably tokens/second.
        return [total_token / _rate / 60 / 60 for _rate in _rates]

    def _set_ticks(_panel, _positions, _labels):
        # Place x ticks and give them their text labels.
        _panel.set_xticks(_positions)
        _panel.set_xticklabels(_labels)

    def _annotate(_panel, _px, _py, _message, _va, _ha, _size=12):
        # Write a note at axes-fraction coordinates, in the highlight colour.
        _panel.text(
            _px,
            _py,
            _message,
            fontsize=_size,
            va=_va,
            ha=_ha,
            transform=_panel.transAxes,
            color=c_4,
        )

    # Colour/marker pair for each method.
    _style_1f1b = {"color": c_1, "marker": "v"}
    _style_fsdp = {"color": c_3, "marker": "^"}
    _style_tp = {"color": c_2, "marker": "o"}
    _style_mlora = {"color": c_4, "marker": "*"}
    _style_peft = {"color": c_2, "marker": "^"}

    # ------------------------------------------------------------------
    # Measured throughputs
    # ------------------------------------------------------------------
    # Single A6000 GPU ("m" series is plotted as mLoRA, "p" series as PEFT).
    b1_tp_m_s = [
        2857.40228, 3016.124377, 3043.99588, 3047.335256,
        3051.551977, 3051.512532, 3048.015064, 3047.108509,
        3048.642661, 3051.840965, 3047.57159, 3047.861865,
    ]
    b1_tp_p_s = [
        2857.389389, 2842, 2851, 2847,
        2853, 2841, 2843, 2851,
        2850, 2851.3, 2849, 2852,
    ]
    b7_tp_m_s = [702.5522131, 716.7349242, 722.2261427, 725.5761517, 727.0030057]
    b7_tp_p_s = [702.6845352, 699.1034186, 701.3211972, 700.1283237, 700.1098232]
    b13_tp_m_s = [398.8387303, 403.5820717, 405.8601994]
    b13_tp_p_s = [398.7009553, 398.4052117, 399.1230098]

    # 4 x A6000 GPUs.
    b1_throughput_mLoRA = [
        4600.45, 8664.91, 10118.36, 10184.44,
        10119, 11157, 11157, 11530,
        11580, 11580, 11600, 11600,
    ]
    b1_throughput_tp = [
        5752.14, 5749.34, 5756.78, 5758.32,
        5753.14, 5753.34, 5756.78, 5756.32,
        5758.14, 5749.34, 5753.78, 5754.32,
    ]
    b1_throughput_fsdp = [
        6151.91, 6141.73, 6161.23, 6153.93,
        6153.91, 6146.73, 6161.23, 6151.93,
        6157.91, 6143.73, 6161.23, 6153.93,
    ]
    b1_throughput_gpipe = [
        4599.87, 4610.19, 4592.17, 4601.18,
        4598.87, 4600.19, 4593.17, 4601.18,
        4599.87, 4610.19, 4592.17, 4603.18,
    ]
    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]
    b13_throughput_mLoRA = [723.21, 1280.54]
    b13_throughput_tp = [875, 877]
    # Placeholder zeros for the failed FSDP run; never plotted or divided by.
    b13_throughput_fsdp = [0, 0, 0, 0]
    b13_throughput_gpipe = [723.21, 719.21]
    b70_throughput_mLoRA = [234, 291, 320, 318]
    b70_throughput_gpipe = [234.34, 234.32, 234.38, 234]

    # 8 x 3090 GPUs.
    # NOTE(review): the variables say "4090" while the original comment and
    # the panel titles say "3090" — confirm which GPU was actually used.
    b1_4090_mlora = [
        319.61, 580.34, 663.12, 799.69,
        800.96, 813.64, 812.92, 814.93,
    ]
    b1_4090_tp = [35.17, 35.33, 35.32, 35.32, 35.38, 35.34, 35.33, 35.34]
    b1_4090_fsdp = [62.79, 62.31, 60.03, 61.53, 62.79, 60.37, 61.23, 63.11]
    b1_4090_gpipe = [
        318.99, 319.16, 319.61, 319.42,
        319.23, 319.61, 319.80, 319.96,
    ]
    b7_4090_mlora = [576.79, 671.39, 702.91, 715.85]
    b7_4090_gpipe = [578.97, 581.46, 573.70, 580.84]
    b7_4090_fsdp = []
    b7_4090_tp = [11.93, 12.09, 11.86, 12.10]
    b13_4090_mlora = [524.57, 694.34]
    b13_4090_gpipe = [528.47, 522.06]
    b13_4090_fsdp = []
    b13_4090_tp = []

    # ------------------------------------------------------------------
    # Row 0: 4 x A6000
    # ------------------------------------------------------------------
    # (a) 1.1B
    x = list(range(1, 13))
    ticks = list(range(2, 13, 2))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b1_throughput_gpipe)
    fsdp_avg_time = _hours(b1_throughput_fsdp)
    tp_avg_time = _hours(b1_throughput_tp)
    m_avg_time = _hours(b1_throughput_mLoRA)
    _panel = ax[0, 0]
    _panel.plot(x, g_avg_time, **_style_1f1b)
    _panel.plot(x, fsdp_avg_time, **_style_fsdp)
    _panel.plot(x, tp_avg_time, **_style_tp)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (b) 7B — this panel supplies the "1F1B", "FSDP" and "TP" legend entries.
    # The mLoRA curve is left unlabelled here (its label is set in panel (i)),
    # presumably so the figure legend does not list it twice.
    x = list(range(1, 5))
    ticks = list(range(1, 5))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b7_throughput_gpipe)
    fsdp_avg_time = _hours(b7_throughput_fsdp)
    tp_avg_time = _hours(b7_throughput_tp)
    m_avg_time = _hours(b7_throughput_mLoRA)
    _panel = ax[0, 1]
    _panel.plot(x, g_avg_time, label="1F1B", **_style_1f1b)
    _panel.plot(x, fsdp_avg_time, label="FSDP", **_style_fsdp)
    _panel.plot(x, tp_avg_time, label="TP", **_style_tp)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (c) 13B — no FSDP curve (annotated as OOM below).
    x = list(range(1, 3))
    ticks = list(range(1, 3))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b13_throughput_gpipe)
    tp_avg_time = _hours(b13_throughput_tp)
    m_avg_time = _hours(b13_throughput_mLoRA)
    _panel = ax[0, 2]
    _panel.plot(x, g_avg_time, **_style_1f1b)
    _panel.plot(x, tp_avg_time, **_style_tp)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (d) 70B — only 1F1B and mLoRA curves.
    x = list(range(1, 5))
    ticks = list(range(1, 5))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b70_throughput_gpipe)
    m_avg_time = _hours(b70_throughput_mLoRA)
    _panel = ax[0, 3]
    _panel.plot(x, g_avg_time, **_style_1f1b)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # ------------------------------------------------------------------
    # Row 1: 8 x 3090
    # ------------------------------------------------------------------
    # (e) 1.1B
    x = list(range(1, 9))
    ticks = list(range(2, 9, 2))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b1_4090_gpipe)
    fsdp_avg_time = _hours(b1_4090_fsdp)
    tp_avg_time = _hours(b1_4090_tp)
    m_avg_time = _hours(b1_4090_mlora)
    _panel = ax[1, 0]
    _panel.plot(x, g_avg_time, **_style_1f1b)
    _panel.plot(x, fsdp_avg_time, **_style_fsdp)
    _panel.plot(x, tp_avg_time, **_style_tp)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (f) 7B — the FSDP and TP series are computed but not drawn; the panel
    # carries a text note for them instead.
    x = list(range(1, 5))
    ticks = list(range(1, 5))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b7_4090_gpipe)
    fsdp_avg_time = _hours(b7_4090_fsdp)
    tp_avg_time = _hours(b7_4090_tp)
    m_avg_time = _hours(b7_4090_mlora)
    _panel = ax[1, 1]
    _panel.plot(x, g_avg_time, **_style_1f1b)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (g) 13B — the FSDP and TP lists are empty, so only 1F1B and mLoRA are drawn.
    x = list(range(1, 3))
    ticks = list(range(1, 3))
    ticks_label = [str(_t) for _t in ticks]
    g_avg_time = _hours(b13_4090_gpipe)
    fsdp_avg_time = _hours(b13_4090_fsdp)
    tp_avg_time = _hours(b13_4090_tp)
    m_avg_time = _hours(b13_4090_mlora)
    _panel = ax[1, 2]
    _panel.plot(x, g_avg_time, **_style_1f1b)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (h) 70B — no data; only ticks plus the text note added below.
    x = list(range(1, 3))
    ticks = list(range(1, 3))
    ticks_label = [str(_t) for _t in ticks]
    _set_ticks(ax[1, 3], ticks, ticks_label)

    # ------------------------------------------------------------------
    # Row 2: single A6000
    # ------------------------------------------------------------------
    # (i) 1.1B — this panel supplies the "PEFT" and "mLoRA" legend entries.
    x = list(range(1, 13))
    ticks = list(range(2, 13, 2))
    ticks_label = [str(_t) for _t in ticks]
    p_avg_time = _hours(b1_tp_p_s)
    m_avg_time = _hours(b1_tp_m_s)
    _panel = ax[2, 0]
    _panel.plot(x, p_avg_time, label="PEFT", **_style_peft)
    _panel.plot(x, m_avg_time, label="mLoRA", **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (j) 7B
    x = list(range(1, 6))
    ticks = list(range(1, 6))
    ticks_label = [str(_t) for _t in ticks]
    p_avg_time = _hours(b7_tp_p_s)
    m_avg_time = _hours(b7_tp_m_s)
    _panel = ax[2, 1]
    _panel.plot(x, p_avg_time, **_style_peft)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (k) 13B
    x = list(range(1, 4))
    ticks = list(range(1, 4))
    ticks_label = [str(_t) for _t in ticks]
    p_avg_time = _hours(b13_tp_p_s)
    m_avg_time = _hours(b13_tp_m_s)
    _panel = ax[2, 2]
    _panel.plot(x, p_avg_time, **_style_peft)
    _panel.plot(x, m_avg_time, **_style_mlora)
    _set_ticks(_panel, ticks, ticks_label)

    # (l) 70B — no data; only ticks plus the text note added below.
    x = list(range(1, 3))
    ticks = list(range(1, 3))
    ticks_label = [str(_t) for _t in ticks]
    _set_ticks(ax[2, 3], ticks, ticks_label)

    # ------------------------------------------------------------------
    # Axis ranges, titles and labels
    # ------------------------------------------------------------------
    # y-range per panel, laid out like the subplot grid.
    _y_ranges = (
        ((0, 3), (0, 10), (0, 20), (0, 60)),
        ((0, 400), (0, 40), (0, 40), (0, 40)),
        ((3, 4), (13, 16), (25, 27), (25, 27)),
    )
    for _axes_row, _range_row in zip(ax, _y_ranges):
        for _panel, _bounds in zip(_axes_row, _range_row):
            _panel.set_ylim(*_bounds)

    # Explicit x-ranges for a subset of panels: (row, col), margin, upper.
    # Each range is (margin, upper - margin).
    for (_r, _c), _margin, _upper in (
        ((0, 2), 0.8, 3),
        ((1, 0), 0.5, 9),
        ((1, 1), 0.5, 5),
        ((1, 2), 0.8, 3),
        ((2, 2), 0.8, 4),
    ):
        ax[_r, _c].set_xlim(_margin, _upper - _margin)

    # Panel captions in row-major order, matching ax.flat.
    _captions = (
        "(a) 1.1B A6000×4",
        "(b) 7B A6000×4",
        "(c) 13B A6000×4",
        "(d) 70B A6000×4",
        "(e) 1.1B 3090×8",
        "(f) 7B 3090×8",
        "(g) 13B 3090×8",
        "(h) 70B 3090×8",
        "(i) 1.1B A6000",
        "(j) 7B A6000",
        "(k) 13B A6000",
        "(l) 70B A6000",
    )
    for _panel, _caption in zip(ax.flat, _captions):
        _panel.set_title(_caption, fontsize=16)

    # Only the left-most panel of each row carries the y-axis label.
    for _panel in ax[:, 0]:
        _panel.set_ylabel("Average task\ncompletion time (h)")

    # ------------------------------------------------------------------
    # Notes for methods that have no curve in a panel
    # ------------------------------------------------------------------
    _annotate(ax[0, 2], 0.9, 0.8, "FSDP : OOM", "bottom", "right")
    _annotate(ax[0, 3], 0.9, 0.1, "FSDP : OOM\nTP : OOM", "bottom", "right")
    _annotate(
        ax[1, 1], 0.9, 0.7, "FSDP : OOM\nTP : about one month", "bottom", "right"
    )
    _annotate(ax[1, 2], 0.9, 0.7, "FSDP : OOM\nTP : OOM", "bottom", "right")
    _annotate(
        ax[1, 3],
        0.5,
        0.5,
        "FSDP : OOM\nTP : OOM\n1F1B : OOM\nmLoRA : OOM",
        "center",
        "center",
    )
    _annotate(ax[2, 3], 0.5, 0.5, "PEFT : OOM\nmLoRA : OOM", "center", "center")

    # Figure-level legend built from the labelled lines of panels (b) and (i).
    fig.legend(
        ncol=5,
        bbox_to_anchor=(0.75, 1.05),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    # Arrow and caption in panel (a).
    ax[0, 0].arrow(5, 0.5, 0.5, 0.3, width=0.01, head_width=0.1)
    _annotate(ax[0, 0], 0.38, 0.07, "enable BatchLoRA", "bottom", "right", 10)

    # Shared x-axis label for the whole figure, placed below the grid.
    fig.supxlabel(
        "Number of simultaneously trained LoRA adapters",
        fontsize=16,
        y=-0.03,
        ha="center",
        va="bottom",
    )
    plt.savefig("end-to-end.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_4090_fsdp,
        b13_4090_gpipe,
        b13_4090_mlora,
        b13_4090_tp,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_4090_fsdp,
        b1_4090_gpipe,
        b1_4090_mlora,
        b1_4090_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b1_tp_m_s,
        b1_tp_p_s,
        b70_throughput_gpipe,
        b70_throughput_mLoRA,
        b7_4090_fsdp,
        b7_4090_gpipe,
        b7_4090_mlora,
        b7_4090_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        b7_tp_m_s,
        b7_tp_p_s,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp_avg_time,
        g_avg_time,
        m_avg_time,
        p_avg_time,
        ticks,
        ticks_label,
        total_token,
        tp_avg_time,
        x,
    )
@app.cell
def __():
    # Empty trailing cell: runs no code and defines nothing.
    return
# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()