505 lines
14 KiB
Python
505 lines
14 KiB
Python
import marimo
|
|
|
|
# Version string of the marimo release that generated this notebook file.
__generated_with = "0.9.17"
# The notebook application object; every cell below registers itself on it
# through the @app.cell decorator.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Dependency cell: import the plotting/numeric/random modules once and
    # hand them to the other cells through the return value.
    import random

    import numpy as np
    import matplotlib.pyplot as plt

    return (np, plt, random)
|
|
|
|
|
|
@app.cell
def __(plt):
    # Set the global Matplotlib text defaults (font family and base size)
    # in a single update instead of key-by-key assignments.
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 16})
    return
|
|
|
|
|
|
@app.cell(hide_code=True)
def __(np, plt, random):
    # Draws a 3x3 grid comparing PEFT and BatchLoRA and saves it to
    # "batchlora_op_task.pdf".  Rows are model sizes (1.1B, 7B, 13B); columns
    # are (a) training time, (b) kernel launch time, (c) kernel execution time.
    #
    # NOTE(review): the column (c) series are NOT measurements.  They are
    # built below from a per-model base value plus unseeded random jitter, and
    # column (b) is derived by subtracting them from column (a).  The saved
    # figure therefore differs slightly on every run.  The measured
    # *_kern_launch_time / *_kern_exec_time arrays are defined and returned
    # but never plotted; an earlier variant (previously left commented out)
    # plotted `t / 1e9 / 20 / cnt` from those arrays for the 7B row.  Confirm
    # which data source is intended before publishing.

    fig, ax = plt.subplots(figsize=(7, 4), ncols=3, nrows=3, layout="constrained")

    # Colour palette; only c_2 (PEFT) and c_4 (BatchLoRA) are used for drawing,
    # the others are kept because the cell exports them.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Per-model series indexed by (adapter count - 1).  The "_m_" arrays feed
    # the BatchLoRA line and the "_p_" arrays feed the PEFT line.
    # Presumably throughput values (the plot divides 512 * 8 by them and
    # labels the result "Time (s)") -- TODO confirm units.
    b1_tp_m_s = np.array(
        [
            2857.40228,
            3016.124377,
            3043.99588,
            3047.335256,
            3051.551977,
            3051.512532,
            3048.015064,
            3047.108509,
            3048.642661,
            3051.840965,
            3047.57159,
            3047.861865,
        ]
    )
    b1_tp_p_s = np.array(
        [
            2857.389389,
            2842,
            2851,
            2847,
            2853,
            2841,
            2843,
            2851,
            2850,
            2851.3,
            2849,
            2852,
        ]
    )

    # Not used by the figure; retained because the cell exports them.
    pt = [1.2955913555992142, 5.493869969292716, 9.928009143652595]
    bt = [1.2923869132290187, 5.493329344922422, 9.919060664229837]

    b7_tp_m_s = np.array(
        [692.5522131, 716.7349242, 723.2261427, 725.5761517, 727.0030057]
    )
    b7_tp_p_s = np.array(
        [692.6845352, 693.1034186, 691.3211972, 692.5283237, 690.1098232]
    )

    b13_tp_m_s = np.array([398.8387303, 403.5820717, 405.8601994])
    b13_tp_p_s = np.array([398.7009553, 398.4052117, 397.1230098])

    # Timing arrays below are exported by the cell but not plotted
    # (see the NOTE at the top of this cell).
    b1_total_time = [
        12220066142,
        16686073253,
        28153512064,
        39069507033,
        51768122088,
        64214141018,
    ]
    b1_kern_launch_time = [
        5118037356,
        4186784734,
        3897274601,
        3017682983,
        3805590038,
        4490765774,
    ]
    b1_kern_exec_time = [
        7102028786,
        12499288519,
        24256237463,
        36051824050,
        47962532050,
        59723375244,
    ]

    b1_peft_total_time = [
        12325794153,
        23577089975,
        46729725461,
        72731733267,
        92082870159,
        1.17119e11,
    ]
    b1_peft_kern_launch_time = [
        5772647608,
        10437658459,
        20388356408,
        33118262295,
        39267405045,
        51109443930,
    ]
    b1_peft_kern_exec_time = [
        6553146545,
        13139431516,
        26341369053,
        39613470972,
        52815465114,
        66009300780,
    ]

    b7_total_time = [
        33120491718,
        57765632980,
        82496377307,
        1.09174e11,
        1.33464e11,
        1.62382e11,
    ]
    b7_kern_launch_time = [
        3662020415,
        2384382297,
        2776672210,
        2172410060,
        2099477024,
        2163734852,
    ]
    b7_kern_exec_time = [
        29458471303,
        55381250683,
        79719705097,
        1.07001e11,
        1.31365e11,
        1.60218e11,
    ]

    b7_peft_total_time = [
        33524811009,
        66969586849,
        1.00781e11,
        1.34885e11,
        1.67287e11,
        1.99614e11,
    ]
    b7_peft_kern_launch_time = [
        5231378776,
        9778801860,
        14391477684,
        19526454140,
        22768133008,
        26224092934,
    ]
    b7_peft_kern_exec_time = [
        28293432233,
        57190784989,
        86389988653,
        1.15359e11,
        1.44519e11,
        1.7339e11,
    ]

    b13_total_time = [58161406999, 1.02225e11, 1.4918e11, 1.9696e11]
    b13_kern_launch_time = [5415715572, 3854386236, 3606118159, 3500385192]
    b13_kern_exec_time = [52745691427, 98370557303, 1.45574e11, 1.93459e11]

    b13_peft_total_time = [58075574121, 1.1545e11, 1.73676e11, 2.30449e11]
    b13_peft_kern_launch_time = [7477607579, 13302253495, 19888792191, 25492785326]
    b13_peft_kern_exec_time = [50597966542, 1.02148e11, 1.53788e11, 2.04956e11]

    # Numerator shared by every "training time" series.
    # Presumably the amount of work per step (e.g. 512 x 8 items) -- TODO confirm.
    _work_per_step = 512 * 8

    def _draw_panel(
        axis,
        xs,
        peft,
        batchlora,
        *,
        xticks,
        xtick_labels,
        ylim,
        yticks,
        ytick_labels,
        model,
        pad=7,
        title=None,
        ylabel=None,
        xlabel=None,
        labelled=False,
    ):
        # Plot the PEFT and BatchLoRA series on one subplot and apply the
        # styling shared by all nine panels.  Only a `labelled` panel
        # contributes entries to the figure-level legend.
        peft_kwargs = {"label": "PEFT"} if labelled else {}
        lora_kwargs = {"label": "BatchLoRA"} if labelled else {}
        axis.plot(xs, peft, color=c_2, marker="o", **peft_kwargs)
        axis.plot(xs, batchlora, color=c_4, marker="*", **lora_kwargs)

        axis.set_ylim(*ylim)
        axis.set_xticks(xticks)
        axis.set_xticklabels(xtick_labels)
        if ylabel is not None:
            axis.set_ylabel(ylabel)
        # Fix the tick positions first so the explicit labels line up with them.
        axis.set_yticks(yticks)
        axis.set_yticklabels(
            ytick_labels, rotation=90, ha="center", va="center", fontsize=12
        )
        axis.tick_params(pad=pad)
        if title is not None:
            axis.set_title(title, fontsize=14, pad=9)
        if xlabel is not None:
            axis.set_xlabel(xlabel)
        # Model-size tag in the upper-right corner, in axes coordinates.
        axis.text(
            0.95,
            0.95,
            model,
            fontsize=10,
            va="top",
            ha="right",
            transform=axis.transAxes,
        )

    # ---------------------------------------------------------------- 1.1B
    # Synthesized kernel-execution series: base value plus random jitter.
    base_b1 = 1.3109530583214795
    b1_k_time_lora = [1.3109530583214795]
    b1_k_time_peft = [1.3109530583214795]
    for i in range(1, 12):
        b1_k_time_lora.append(base_b1 - 0.004 - 0.001 * random.random())
        b1_k_time_peft.append(base_b1 - 0.001 * random.random() + 0.0005)

    # x holds adapter counts; each count is also the 1-based index into the
    # per-model arrays above.
    x = [1, 2, 4, 6, 8, 10, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]

    y_peft = [_work_per_step / b1_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [_work_per_step / b1_tp_m_s[cnt - 1] for cnt in x]
    _draw_panel(
        ax[0][0],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(1.3, 1.5),
        yticks=[1.3, 1.4, 1.5],
        ytick_labels=["1.3", "1.4", "1.5"],
        model="Model:1.1B",
        pad=7,
        title="(a) Training time",
        ylabel="Time (s)",
        labelled=True,
    )

    # Launch time is derived as (training time - synthesized execution time).
    y_peft = [
        _work_per_step / b1_tp_p_s[cnt - 1] - b1_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        _work_per_step / b1_tp_m_s[cnt - 1] - b1_k_time_lora[cnt - 1] for cnt in x
    ]
    _draw_panel(
        ax[0][1],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(0, 0.2),
        yticks=[0, 0.1, 0.2],
        ytick_labels=["0", "0.1", "0.2"],
        model="Model:1.1B",
        pad=7,
        title="(b) Kernel launch time",
    )

    y_peft = [b1_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b1_k_time_lora[cnt - 1] for cnt in x]
    _draw_panel(
        ax[0][2],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(1.2, 1.4),
        yticks=[1.2, 1.3, 1.4],
        # The lowest tick sits at 1.2, so it is labelled "1.2" (was "1").
        ytick_labels=["1.2", "1.3", "1.4"],
        model="Model:1.1B",
        pad=8,
        title="(c) Kernel execution time",
    )

    # ------------------------------------------------------------------ 7B
    base_b7 = 5.515842989778662
    # NOTE(review): the first BatchLoRA entry is base_b7 + 0.01 rather than
    # base_b7 itself -- confirm this offset is intentional.
    b7_k_time_lora = [5.525842989778662]
    b7_k_time_peft = [5.515842989778662]
    for i in range(1, 5):
        b7_k_time_lora.append(base_b7 - 0.007 - 0.001 * random.random())
        b7_k_time_peft.append(base_b7 - 0.001 * random.random() + 0.0005)

    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]

    y_peft = [_work_per_step / b7_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [_work_per_step / b7_tp_m_s[cnt - 1] for cnt in x]
    _draw_panel(
        ax[1][0],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(5.6, 6.1),
        yticks=[5.6, 5.8, 6],
        ytick_labels=["5.6", "5.8", "6"],
        model="Model:7B",
        pad=7,
        ylabel="Time (s)",
    )

    y_peft = [
        _work_per_step / b7_tp_p_s[cnt - 1] - b7_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        _work_per_step / b7_tp_m_s[cnt - 1] - b7_k_time_lora[cnt - 1] for cnt in x
    ]
    _draw_panel(
        ax[1][1],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(0, 0.6),
        yticks=[0, 0.3, 0.6],
        ytick_labels=["0", "0.3", "0.6"],
        model="Model:7B",
        pad=7,
    )

    y_peft = [b7_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b7_k_time_lora[cnt - 1] for cnt in x]
    _draw_panel(
        ax[1][2],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(5.4, 5.6),
        yticks=[5.4, 5.5, 5.6],
        ytick_labels=["5.4", "5.5", "5.6"],
        model="Model:7B",
        pad=7,
    )

    # ----------------------------------------------------------------- 13B
    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]

    y_peft = [_work_per_step / b13_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [_work_per_step / b13_tp_m_s[cnt - 1] for cnt in x]
    _draw_panel(
        ax[2][0],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(10, 10.6),
        yticks=[10, 10.3, 10.6],
        ytick_labels=["10", "10.3", "10.6"],
        model="Model:13B",
        pad=6,
        ylabel="Time (s)",
    )

    base_b13 = 9.989694870384067
    b13_k_time_lora = [9.989694870384067]
    # NOTE(review): the first PEFT entry is base_b13 - 0.01 rather than
    # base_b13 itself -- confirm this offset is intentional.
    b13_k_time_peft = [9.979694870384067]
    for i in range(1, 3):
        b13_k_time_lora.append(base_b13 - 0.007 - 0.001 * random.random())
        b13_k_time_peft.append(base_b13 - 0.001 * random.random() + 0.0005)

    y_peft = [
        _work_per_step / b13_tp_p_s[cnt - 1] - b13_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        _work_per_step / b13_tp_m_s[cnt - 1] - b13_k_time_lora[cnt - 1] for cnt in x
    ]
    _draw_panel(
        ax[2][1],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(0, 0.6),
        yticks=[0.0, 0.3, 0.6],
        ytick_labels=["0", "0.3", "0.6"],
        model="Model:13B",
        pad=7,
        # The bottom-centre panel carries the x-axis label for the whole grid.
        xlabel="Number of simultaneously trained LoRA adapters",
    )

    y_peft = [b13_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b13_k_time_lora[cnt - 1] for cnt in x]
    _draw_panel(
        ax[2][2],
        x,
        y_peft,
        y_batchlora,
        xticks=ticks,
        xtick_labels=ticks_label,
        ylim=(9.9, 10.1),
        yticks=[9.9, 10, 10.1],
        ytick_labels=["9.9", "10", "10.1"],
        model="Model:13B",
        pad=7,
    )

    # Figure-level legend, built from the labelled lines of the first panel
    # and anchored above the grid.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 1.1),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    plt.savefig("batchlora_op_task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_k_time_lora,
        b13_k_time_peft,
        b13_kern_exec_time,
        b13_kern_launch_time,
        b13_peft_kern_exec_time,
        b13_peft_kern_launch_time,
        b13_peft_total_time,
        b13_total_time,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_k_time_lora,
        b1_k_time_peft,
        b1_kern_exec_time,
        b1_kern_launch_time,
        b1_peft_kern_exec_time,
        b1_peft_kern_launch_time,
        b1_peft_total_time,
        b1_total_time,
        b1_tp_m_s,
        b1_tp_p_s,
        b7_k_time_lora,
        b7_k_time_peft,
        b7_kern_exec_time,
        b7_kern_launch_time,
        b7_peft_kern_exec_time,
        b7_peft_kern_launch_time,
        b7_peft_total_time,
        b7_total_time,
        b7_tp_m_s,
        b7_tp_p_s,
        base_b1,
        base_b13,
        base_b7,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        i,
        pt,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_peft,
    )
|
|
|
|
|
|
# Run the notebook's cells when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
|