paper_note/mlora/end-to-end-total.py
2025-03-05 20:38:41 +08:00

618 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import marimo
# Version string recorded in the notebook file (presumably the marimo release
# that generated it -- confirm against marimo's file-format docs).
__generated_with = "0.9.17"
# The marimo application object; notebook cells are registered on it through
# the `@app.cell` decorator, and it is started with `app.run()`.
app = marimo.App(width="medium")
@app.cell
def __():
    # Import cell: loads the numeric and plotting libraries and hands them to
    # the other cells through the returned tuple.
    import numpy as np

    import matplotlib
    import matplotlib.pyplot as plt

    return matplotlib, np, plt
@app.cell
def __(matplotlib, plt):
    # Global matplotlib styling for the figure: no LaTeX text rendering,
    # Times New Roman as the font family, 16 pt base font size.
    matplotlib.rcParams.update({"text.usetex": False})
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return
@app.cell
def __(np, plt):
    # Draws a 3x4 grid of line charts ("task completion time" against the
    # "number of trained LoRA adapters") from hard-coded throughput
    # measurements and saves the figure as end-to-end-total.pdf.
    #
    # Rows:    4x A6000, 8x 3090, single A6000.
    # Columns: 1.1B, 7B, 13B, 70B models (see the panel titles below).
    #
    # Helper functions and loop variables are prefixed with "_" so that they
    # stay private to this marimo cell; the names exported by the final
    # `return` are exactly the ones the cell exported before.

    fig, ax = plt.subplots(
        figsize=(14, 14 / 2.2), ncols=4, nrows=3, layout="constrained"
    )

    # Line colours as RGB fractions.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Work per task; presumably a token count (samples x sequence length x
    # epochs) -- TODO confirm the meaning of the three factors.
    total_token = 7473 * 512 * 10

    def _completion_hours(throughputs):
        # Entry k (1-based) of `throughputs` is used as the throughput for
        # k tasks, so the time for k tasks is k * total_token / throughput.
        # Dividing by 60 twice turns the result into hours, assuming the
        # throughput is given per second (the y-axis is labelled in hours).
        return [
            total_token * count / tp / 60 / 60
            for count, tp in enumerate(throughputs, start=1)
        ]

    def _draw_panel(axis, xs, curves, tick_positions, tick_text):
        # Plot every (ys, colour, marker, label) curve in the given order and
        # then set the x ticks. `label` is only forwarded when it is not
        # None, so unlabeled curves stay out of the figure legend.
        for ys, colour, marker, label in curves:
            extra = {} if label is None else {"label": label}
            axis.plot(xs, ys, color=colour, marker=marker, **extra)
        axis.set_xticks(tick_positions)
        axis.set_xticklabels(tick_text)

    def _annotate(axis, x_frac, y_frac, message, va, ha):
        # Put a small note inside a panel, positioned in axes coordinates.
        axis.text(
            x_frac,
            y_frac,
            message,
            fontsize=12,
            va=va,
            ha=ha,
            transform=axis.transAxes,
            color=c_4,
        )

    # ------------------------------------------------------------------
    # Measurements
    # ------------------------------------------------------------------
    # Single A6000 GPU
    b1_tp_m_s = [
        2857.40228,
        3016.124377,
        3043.99588,
        3047.335256,
        3051.551977,
        3051.512532,
        3048.015064,
        3047.108509,
        3048.642661,
        3051.840965,
        3047.57159,
        3047.861865,
    ]
    b1_tp_p_s = [
        2857.389389,
        2842,
        2851,
        2847,
        2853,
        2841,
        2843,
        2851,
        2850,
        2851.3,
        2849,
        2852,
    ]
    b7_tp_m_s = [702.5522131, 716.7349242, 722.2261427, 725.5761517, 727.0030057]
    b7_tp_p_s = [702.6845352, 699.1034186, 701.3211972, 700.1283237, 700.1098232]
    b13_tp_m_s = [398.8387303, 403.5820717, 405.8601994]
    b13_tp_p_s = [398.7009553, 398.4052117, 399.1230098]

    # 4x A6000 GPUs
    b1_throughput_mLoRA = [
        4600.45,
        8664.91,
        10118.36,
        10184.44,
        10119,
        11157,
        11157,
        11530,
        11580,
        11580,
        11600,
        11600,
    ]
    b1_throughput_tp = [
        5752.14,
        5749.34,
        5756.78,
        5758.32,
        5753.14,
        5753.34,
        5756.78,
        5756.32,
        5758.14,
        5749.34,
        5753.78,
        5754.32,
    ]
    b1_throughput_fsdp = [
        6151.91,
        6141.73,
        6161.23,
        6153.93,
        6153.91,
        6146.73,
        6161.23,
        6151.93,
        6157.91,
        6143.73,
        6161.23,
        6153.93,
    ]
    b1_throughput_gpipe = [
        4599.87,
        4610.19,
        4592.17,
        4601.18,
        4598.87,
        4600.19,
        4593.17,
        4601.18,
        4599.87,
        4610.19,
        4592.17,
        4603.18,
    ]
    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]
    b13_throughput_mLoRA = [723.21, 1280.54]
    b13_throughput_tp = [875, 877]
    # Placeholder for the failed FSDP run; never plotted (panel (c) carries
    # an "FSDP : OOM" note instead). Do not feed it to _completion_hours:
    # the zeros would raise ZeroDivisionError.
    b13_throughput_fsdp = [0, 0, 0, 0]
    b13_throughput_gpipe = [723.21, 719.21]
    b70_throughput_mLoRA = [234, 291, 320, 318]
    b70_throughput_gpipe = [234.34, 234.32, 234.38, 234]

    # 8x 3090 GPUs according to the original comment and the panel titles.
    # NOTE(review): the variable names say "4090" -- confirm which is right.
    b1_4090_mlora = [
        319.61,
        580.34,
        663.12,
        799.69,
        800.96,
        813.64,
        812.92,
        814.93,
    ]
    b1_4090_tp = [35.17, 35.33, 35.32, 35.32, 35.38, 35.34, 35.33, 35.34]
    b1_4090_fsdp = [62.79, 62.31, 60.03, 61.53, 62.79, 60.37, 61.23, 63.11]
    b1_4090_gpipe = [
        318.99,
        319.16,
        319.61,
        319.42,
        319.23,
        319.61,
        319.80,
        319.96,
    ]
    b7_4090_mlora = [576.79, 671.39, 702.91, 715.85]
    b7_4090_gpipe = [578.97, 581.46, 573.70, 580.84]
    b7_4090_fsdp = []
    b7_4090_tp = [11.93, 12.09, 11.86, 12.10]
    b13_4090_mlora = [524.57, 694.34]
    b13_4090_gpipe = [528.47, 522.06]
    b13_4090_fsdp = []
    b13_4090_tp = []

    # ------------------------------------------------------------------
    # Row 0: 4x A6000
    # ------------------------------------------------------------------
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]
    m_avg_time = _completion_hours(b1_throughput_mLoRA)
    tp_avg_time = _completion_hours(b1_throughput_tp)
    fsdp_avg_time = _completion_hours(b1_throughput_fsdp)
    g_avg_time = _completion_hours(b1_throughput_gpipe)
    _draw_panel(
        ax[0][0],
        x,
        [
            (g_avg_time, c_1, "v", None),
            (fsdp_avg_time, c_3, "^", None),
            (tp_avg_time, c_2, "o", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _completion_hours(b7_throughput_mLoRA)
    tp_avg_time = _completion_hours(b7_throughput_tp)
    fsdp_avg_time = _completion_hours(b7_throughput_fsdp)
    g_avg_time = _completion_hours(b7_throughput_gpipe)
    # This panel supplies the "1F1B", "FSDP" and "TP" legend entries.
    _draw_panel(
        ax[0][1],
        x,
        [
            (g_avg_time, c_1, "v", "1F1B"),
            (fsdp_avg_time, c_3, "^", "FSDP"),
            (tp_avg_time, c_2, "o", "TP"),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    m_avg_time = _completion_hours(b13_throughput_mLoRA)
    tp_avg_time = _completion_hours(b13_throughput_tp)
    g_avg_time = _completion_hours(b13_throughput_gpipe)
    _draw_panel(
        ax[0][2],
        x,
        [
            (g_avg_time, c_1, "v", None),
            (tp_avg_time, c_2, "o", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _completion_hours(b70_throughput_mLoRA)
    g_avg_time = _completion_hours(b70_throughput_gpipe)
    _draw_panel(
        ax[0][3],
        x,
        [
            (g_avg_time, c_1, "v", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    # ------------------------------------------------------------------
    # Row 1: 8x 3090
    # ------------------------------------------------------------------
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    ticks = [2, 4, 6, 8]
    ticks_label = ["2", "4", "6", "8"]
    m_avg_time = _completion_hours(b1_4090_mlora)
    tp_avg_time = _completion_hours(b1_4090_tp)
    fsdp_avg_time = _completion_hours(b1_4090_fsdp)
    g_avg_time = _completion_hours(b1_4090_gpipe)
    _draw_panel(
        ax[1][0],
        x,
        [
            (g_avg_time, c_1, "v", None),
            (fsdp_avg_time, c_3, "^", None),
            (tp_avg_time, c_2, "o", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _completion_hours(b7_4090_mlora)
    g_avg_time = _completion_hours(b7_4090_gpipe)
    # TP and FSDP are deliberately not drawn here; the panel carries a text
    # note for them instead (see the annotations below).
    _draw_panel(
        ax[1][1],
        x,
        [
            (g_avg_time, c_1, "v", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    m_avg_time = _completion_hours(b13_4090_mlora)
    # Both source lists are empty, so these are empty too and are not drawn.
    # They are still assigned because the cell exports tp_avg_time and
    # fsdp_avg_time, and these are the last values they receive.
    tp_avg_time = _completion_hours(b13_4090_tp)
    fsdp_avg_time = _completion_hours(b13_4090_fsdp)
    g_avg_time = _completion_hours(b13_4090_gpipe)
    _draw_panel(
        ax[1][2],
        x,
        [
            (g_avg_time, c_1, "v", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    # No data for this panel: only ticks plus a text note.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    _draw_panel(ax[1][3], x, [], ticks, ticks_label)

    # ------------------------------------------------------------------
    # Row 2: single A6000
    # ------------------------------------------------------------------
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]
    m_avg_time = _completion_hours(b1_tp_m_s)
    p_avg_time = _completion_hours(b1_tp_p_s)
    # This panel supplies the "PEFT" and "mLoRA" legend entries.
    _draw_panel(
        ax[2][0],
        x,
        [
            (p_avg_time, c_2, "^", "PEFT"),
            (m_avg_time, c_4, "*", "mLoRA"),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]
    m_avg_time = _completion_hours(b7_tp_m_s)
    p_avg_time = _completion_hours(b7_tp_p_s)
    _draw_panel(
        ax[2][1],
        x,
        [
            (p_avg_time, c_2, "^", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]
    m_avg_time = _completion_hours(b13_tp_m_s)
    p_avg_time = _completion_hours(b13_tp_p_s)
    _draw_panel(
        ax[2][2],
        x,
        [
            (p_avg_time, c_2, "^", None),
            (m_avg_time, c_4, "*", None),
        ],
        ticks,
        ticks_label,
    )

    # No data for this panel: only ticks plus a text note.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    _draw_panel(ax[2][3], x, [], ticks, ticks_label)

    # ------------------------------------------------------------------
    # Axis limits, titles and labels
    # ------------------------------------------------------------------
    # Upper y limit per panel, laid out like the grid; the lower limit is 0.
    _y_tops = [
        [30, 40, 40, 200],
        [2500, 80, 50, 1000],
        [50, 100, 100, 100],
    ]
    for _axes_row, _tops_row in zip(ax, _y_tops):
        for _axis, _top in zip(_axes_row, _tops_row):
            _axis.set_ylim(0, _top)

    # Explicit x ranges for selected panels: (axis, number of points, margin).
    # The range runs from `margin` to `points + 1 - margin`.
    _x_ranges = [
        (ax[0][2], 2, 0.8),
        (ax[1][0], 8, 0.5),
        (ax[1][1], 4, 0.5),
        (ax[1][2], 2, 0.8),
        (ax[2][2], 3, 0.8),
    ]
    for _axis, _points, _margin in _x_ranges:
        _axis.set_xlim(_margin, _points + 1 - _margin)

    _titles = [
        [
            "(a) 1.1B A6000×4",
            "(b) 7B A6000×4",
            "(c) 13B A6000×4",
            "(d) 70B A6000×4",
        ],
        [
            "(e) 1.1B 3090×8",
            "(f) 7B 3090×8",
            "(g) 13B 3090×8",
            "(h) 70B 3090×8",
        ],
        [
            "(i) 1.1B A6000",
            "(j) 7B A6000",
            "(k) 13B A6000",
            "(l) 70B A6000",
        ],
    ]
    for _axes_row, _titles_row in zip(ax, _titles):
        for _axis, _title in zip(_axes_row, _titles_row):
            _axis.set_title(_title, fontsize=16)

    # Only the left-most panel of each row gets a y-axis label.
    for _axis in (ax[0][0], ax[1][0], ax[2][0]):
        _axis.set_ylabel("Task\ncompletion time (h)")

    # ------------------------------------------------------------------
    # Text notes for the methods that have no curve in a panel
    # ------------------------------------------------------------------
    _annotate(ax[0][2], 0.9, 0.8, "FSDP : OOM", "bottom", "right")
    _annotate(ax[0][3], 0.9, 0.1, "FSDP : OOM\nTP : OOM", "bottom", "right")
    _annotate(
        ax[1][1],
        0.9,
        0.1,
        "FSDP : OOM\nTP : about one month",
        "bottom",
        "right",
    )
    _annotate(ax[1][2], 0.9, 0.1, "FSDP : OOM\nTP : OOM", "bottom", "right")
    _annotate(
        ax[1][3],
        0.5,
        0.5,
        "FSDP : OOM\nTP : OOM\n1F1B : OOM\nmLoRA : OOM",
        "center",
        "center",
    )
    _annotate(ax[2][3], 0.5, 0.5, "PEFT : OOM\nmLoRA : OOM", "center", "center")

    # ------------------------------------------------------------------
    # Figure-level legend, shared x label, output file
    # ------------------------------------------------------------------
    fig.legend(
        ncol=5,
        bbox_to_anchor=(0.75, 1.05),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )
    fig.supxlabel(
        "Number of trained LoRA adapters",
        fontsize=16,
        y=-0.03,
        ha="center",
        va="bottom",
    )
    plt.savefig("end-to-end-total.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_4090_fsdp,
        b13_4090_gpipe,
        b13_4090_mlora,
        b13_4090_tp,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_4090_fsdp,
        b1_4090_gpipe,
        b1_4090_mlora,
        b1_4090_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b1_tp_m_s,
        b1_tp_p_s,
        b70_throughput_gpipe,
        b70_throughput_mLoRA,
        b7_4090_fsdp,
        b7_4090_gpipe,
        b7_4090_mlora,
        b7_4090_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        b7_tp_m_s,
        b7_tp_p_s,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp_avg_time,
        g_avg_time,
        m_avg_time,
        p_avg_time,
        ticks,
        ticks_label,
        total_token,
        tp_avg_time,
        x,
    )
# Run the notebook's cells when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()