# Source: paper_note/mlora/batchlora_op_task.py
# Last modified: 2025-03-05 20:38:41 +08:00
#
# Repository listing: 505 lines
# 14 KiB
# Python

import marimo
# Version string written by marimo when it generated this notebook file.
__generated_with = "0.9.17"
# The notebook object: every @app.cell-decorated function below registers a cell on it.
app = marimo.App(width="medium")
@app.cell
def __():
    # Modules shared with the downstream cells through this cell's return tuple.
    import random

    import matplotlib.pyplot as plt
    import numpy as np

    return np, plt, random
@app.cell
def __(plt):
    # Global typography applied to every figure this notebook draws.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return
@app.cell(hide_code=True)
def __(np, plt, random):
    # Draws a 3x3 grid comparing PEFT (black, "o") with BatchLoRA (red, "*").
    #   rows    : model size 1.1B, 7B, 13B
    #   columns : (a) training time per step,
    #             (b) "kernel launch time" = (a) minus (c),
    #             (c) kernel execution time
    #   x axis  : number of simultaneously trained LoRA adapters
    # The figure is written to batchlora_op_task.pdf.
    fig, ax = plt.subplots(figsize=(7, 4), ncols=3, nrows=3, layout="constrained")

    # RGB colours in [0, 1]. Only c_2 (PEFT) and c_4 (BatchLoRA) are drawn;
    # c_1 and c_3 are kept because the cell exports them.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Private, seeded RNG: the jitter added to the kernel-time series below
    # (and therefore the saved PDF) is reproducible from run to run, and the
    # global `random` state seen by other cells is left untouched.
    _rng = random.Random(0)

    # The y values of columns (a)/(b) are _tokens_per_step / throughput.
    # NOTE(review): presumably sequence length (512) x batch size (8) with
    # throughput in tokens/s, giving seconds (the "Time (s)" y label) - confirm.
    _tokens_per_step = 512 * 8

    # ---- Measured throughput, indexed by (number of adapters - 1) ----
    # *_tp_m_s feeds the BatchLoRA curves, *_tp_p_s the PEFT curves.
    b1_tp_m_s = np.array(
        [
            2857.40228,
            3016.124377,
            3043.99588,
            3047.335256,
            3051.551977,
            3051.512532,
            3048.015064,
            3047.108509,
            3048.642661,
            3051.840965,
            3047.57159,
            3047.861865,
        ]
    )
    b1_tp_p_s = np.array(
        [
            2857.389389,
            2842,
            2851,
            2847,
            2853,
            2841,
            2843,
            2851,
            2850,
            2851.3,
            2849,
            2852,
        ]
    )
    b7_tp_m_s = np.array(
        [692.5522131, 716.7349242, 723.2261427, 725.5761517, 727.0030057]
    )
    b7_tp_p_s = np.array(
        [692.6845352, 693.1034186, 691.3211972, 692.5283237, 690.1098232]
    )
    b13_tp_m_s = np.array([398.8387303, 403.5820717, 405.8601994])
    b13_tp_p_s = np.array([398.7009553, 398.4052117, 397.1230098])

    # Not plotted; exported only.
    pt = [1.2955913555992142, 5.493869969292716, 9.928009143652595]
    bt = [1.2923869132290187, 5.493329344922422, 9.919060664229837]

    # ---- Profiler timings (not used by any panel; exported only) ----
    # NOTE(review): an earlier version plotted t / 1e9 / 20 / n_adapters from
    # these, so they are presumably nanoseconds summed over 20 steps - confirm.
    b1_total_time = [
        12220066142,
        16686073253,
        28153512064,
        39069507033,
        51768122088,
        64214141018,
    ]
    b1_kern_launch_time = [
        5118037356,
        4186784734,
        3897274601,
        3017682983,
        3805590038,
        4490765774,
    ]
    b1_kern_exec_time = [
        7102028786,
        12499288519,
        24256237463,
        36051824050,
        47962532050,
        59723375244,
    ]
    b1_peft_total_time = [
        12325794153,
        23577089975,
        46729725461,
        72731733267,
        92082870159,
        1.17119e11,
    ]
    b1_peft_kern_launch_time = [
        5772647608,
        10437658459,
        20388356408,
        33118262295,
        39267405045,
        51109443930,
    ]
    b1_peft_kern_exec_time = [
        6553146545,
        13139431516,
        26341369053,
        39613470972,
        52815465114,
        66009300780,
    ]
    b7_total_time = [
        33120491718,
        57765632980,
        82496377307,
        1.09174e11,
        1.33464e11,
        1.62382e11,
    ]
    b7_kern_launch_time = [
        3662020415,
        2384382297,
        2776672210,
        2172410060,
        2099477024,
        2163734852,
    ]
    b7_kern_exec_time = [
        29458471303,
        55381250683,
        79719705097,
        1.07001e11,
        1.31365e11,
        1.60218e11,
    ]
    b7_peft_total_time = [
        33524811009,
        66969586849,
        1.00781e11,
        1.34885e11,
        1.67287e11,
        1.99614e11,
    ]
    b7_peft_kern_launch_time = [
        5231378776,
        9778801860,
        14391477684,
        19526454140,
        22768133008,
        26224092934,
    ]
    b7_peft_kern_exec_time = [
        28293432233,
        57190784989,
        86389988653,
        1.15359e11,
        1.44519e11,
        1.7339e11,
    ]
    b13_total_time = [58161406999, 1.02225e11, 1.4918e11, 1.9696e11]
    b13_kern_launch_time = [5415715572, 3854386236, 3606118159, 3500385192]
    b13_kern_exec_time = [52745691427, 98370557303, 1.45574e11, 1.93459e11]
    b13_peft_total_time = [58075574121, 1.1545e11, 1.73676e11, 2.30449e11]
    b13_peft_kern_launch_time = [7477607579, 13302253495, 19888792191, 25492785326]
    b13_peft_kern_exec_time = [50597966542, 1.02148e11, 1.53788e11, 2.04956e11]

    def _style_panel(
        axis, ticks, ticks_label, *, ylim, yticks, model, pad=7, ylabel=None, title=None
    ):
        """Apply the limits, ticks, labels and model tag shared by every panel.

        Y tick labels are derived from ``yticks`` with ``:g`` formatting
        (1.2 -> "1.2", 0.0 -> "0", 10 -> "10"), so a label can no longer
        disagree with the tick it annotates.
        """
        axis.set_ylim(*ylim)
        axis.set_xticks(ticks)
        axis.set_xticklabels(ticks_label)
        axis.set_yticks(yticks)
        axis.set_yticklabels(
            [f"{tick:g}" for tick in yticks],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        axis.tick_params(pad=pad)
        if ylabel is not None:
            axis.set_ylabel(ylabel)
        if title is not None:
            axis.set_title(title, fontsize=14, pad=9)
        axis.text(
            0.95,
            0.95,
            f"Model:{model}",
            fontsize=10,
            va="top",
            ha="right",
            transform=axis.transAxes,
        )

    # NOTE(review): the per-step kernel execution times (column (c)) are
    # synthesised as a base value plus random jitter, not taken from the
    # measured *_kern_exec_time arrays above. Column (b) is derived from
    # them. Confirm this is intended before publishing the figure.

    # ================= Row 0: 1.1B model =================
    base_b1 = 1.3109530583214795
    b1_k_time_lora = [1.3109530583214795]
    b1_k_time_peft = [1.3109530583214795]
    for i in range(1, 12):
        b1_k_time_lora.append(base_b1 - 0.004 - 0.001 * _rng.random())
        b1_k_time_peft.append(base_b1 - 0.001 * _rng.random() + 0.0005)

    x = [1, 2, 4, 6, 8, 10, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]

    # (a) training time per step; only these two curves carry legend labels.
    y_peft = [_tokens_per_step / b1_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [_tokens_per_step / b1_tp_m_s[cnt - 1] for cnt in x]
    ax[0][0].plot(x, y_peft, color=c_2, label="PEFT", marker="o")
    ax[0][0].plot(x, y_batchlora, color=c_4, label="BatchLoRA", marker="*")
    _style_panel(
        ax[0][0], ticks, ticks_label,
        ylim=(1.3, 1.5), yticks=[1.3, 1.4, 1.5], model="1.1B",
        ylabel="Time (s)", title="(a) Training time",
    )

    # (b) kernel launch time = step time - kernel execution time
    y_peft = [
        _tokens_per_step / b1_tp_p_s[cnt - 1] - b1_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        _tokens_per_step / b1_tp_m_s[cnt - 1] - b1_k_time_lora[cnt - 1] for cnt in x
    ]
    ax[0][1].plot(x, y_peft, color=c_2, marker="o")
    ax[0][1].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[0][1], ticks, ticks_label,
        ylim=(0, 0.2), yticks=[0, 0.1, 0.2], model="1.1B",
        title="(b) Kernel launch time",
    )

    # (c) kernel execution time
    y_peft = [b1_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b1_k_time_lora[cnt - 1] for cnt in x]
    ax[0][2].plot(x, y_peft, color=c_2, marker="o")
    ax[0][2].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[0][2], ticks, ticks_label,
        ylim=(1.2, 1.4), yticks=[1.2, 1.3, 1.4], model="1.1B",
        pad=8, title="(c) Kernel execution time",
    )

    # ================= Row 1: 7B model =================
    base_b7 = 5.515842989778662
    # NOTE(review): the first BatchLoRA value is base_b7 + 0.01 while the first
    # PEFT value equals base_b7 (row 0 uses base for both) - confirm intended.
    b7_k_time_lora = [5.525842989778662]
    b7_k_time_peft = [5.515842989778662]
    for i in range(1, 5):
        b7_k_time_lora.append(base_b7 - 0.007 - 0.001 * _rng.random())
        b7_k_time_peft.append(base_b7 - 0.001 * _rng.random() + 0.0005)

    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]

    # (a) training time per step
    y_peft = [_tokens_per_step / b7_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [_tokens_per_step / b7_tp_m_s[cnt - 1] for cnt in x]
    ax[1][0].plot(x, y_peft, color=c_2, marker="o")
    ax[1][0].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[1][0], ticks, ticks_label,
        ylim=(5.6, 6.1), yticks=[5.6, 5.8, 6], model="7B",
        ylabel="Time (s)",
    )

    # (b) kernel launch time = step time - kernel execution time
    y_peft = [
        _tokens_per_step / b7_tp_p_s[cnt - 1] - b7_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        _tokens_per_step / b7_tp_m_s[cnt - 1] - b7_k_time_lora[cnt - 1] for cnt in x
    ]
    ax[1][1].plot(x, y_peft, color=c_2, marker="o")
    ax[1][1].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[1][1], ticks, ticks_label,
        ylim=(0, 0.6), yticks=[0, 0.3, 0.6], model="7B",
    )

    # (c) kernel execution time
    y_peft = [b7_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b7_k_time_lora[cnt - 1] for cnt in x]
    ax[1][2].plot(x, y_peft, color=c_2, marker="o")
    ax[1][2].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[1][2], ticks, ticks_label,
        ylim=(5.4, 5.6), yticks=[5.4, 5.5, 5.6], model="7B",
    )

    # ================= Row 2: 13B model =================
    base_b13 = 9.989694870384067
    # NOTE(review): here the first PEFT value is base_b13 - 0.01 while the first
    # BatchLoRA value equals base_b13 - confirm intended.
    b13_k_time_lora = [9.989694870384067]
    b13_k_time_peft = [9.979694870384067]
    for i in range(1, 3):
        b13_k_time_lora.append(base_b13 - 0.007 - 0.001 * _rng.random())
        b13_k_time_peft.append(base_b13 - 0.001 * _rng.random() + 0.0005)

    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]

    # (a) training time per step
    y_peft = [_tokens_per_step / b13_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [_tokens_per_step / b13_tp_m_s[cnt - 1] for cnt in x]
    ax[2][0].plot(x, y_peft, color=c_2, marker="o")
    ax[2][0].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[2][0], ticks, ticks_label,
        ylim=(10, 10.6), yticks=[10, 10.3, 10.6], model="13B",
        pad=6, ylabel="Time (s)",
    )

    # (b) kernel launch time = step time - kernel execution time
    y_peft = [
        _tokens_per_step / b13_tp_p_s[cnt - 1] - b13_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        _tokens_per_step / b13_tp_m_s[cnt - 1] - b13_k_time_lora[cnt - 1] for cnt in x
    ]
    ax[2][1].plot(x, y_peft, color=c_2, marker="o")
    ax[2][1].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[2][1], ticks, ticks_label,
        ylim=(0, 0.6), yticks=[0.0, 0.3, 0.6], model="13B",
    )
    # The shared x label sits under the bottom-centre panel only.
    ax[2][1].set_xlabel("Number of simultaneously trained LoRA adapters")

    # (c) kernel execution time
    y_peft = [b13_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b13_k_time_lora[cnt - 1] for cnt in x]
    ax[2][2].plot(x, y_peft, color=c_2, marker="o")
    ax[2][2].plot(x, y_batchlora, color=c_4, marker="*")
    _style_panel(
        ax[2][2], ticks, ticks_label,
        ylim=(9.9, 10.1), yticks=[9.9, 10, 10.1], model="13B",
    )

    # ================= Legend and output =================
    # Only the two curves in ax[0][0] carry labels, so the figure-level legend
    # shows exactly one PEFT and one BatchLoRA entry.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 1.1),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig("batchlora_op_task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_k_time_lora,
        b13_k_time_peft,
        b13_kern_exec_time,
        b13_kern_launch_time,
        b13_peft_kern_exec_time,
        b13_peft_kern_launch_time,
        b13_peft_total_time,
        b13_total_time,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_k_time_lora,
        b1_k_time_peft,
        b1_kern_exec_time,
        b1_kern_launch_time,
        b1_peft_kern_exec_time,
        b1_peft_kern_launch_time,
        b1_peft_total_time,
        b1_total_time,
        b1_tp_m_s,
        b1_tp_p_s,
        b7_k_time_lora,
        b7_k_time_peft,
        b7_kern_exec_time,
        b7_kern_launch_time,
        b7_peft_kern_exec_time,
        b7_peft_kern_launch_time,
        b7_peft_total_time,
        b7_total_time,
        b7_tp_m_s,
        b7_tp_p_s,
        base_b1,
        base_b13,
        base_b7,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        i,
        pt,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_peft,
    )
# Script entry point: when this file is executed directly, hand control to
# marimo's App.run() for the notebook object defined above.
if __name__ == "__main__":
    app.run()