paper_note/mlora/batchlora_op_cmp.py
2025-03-05 20:38:41 +08:00

313 lines
9.0 KiB
Python

import marimo
# Version string stored alongside the notebook; presumably the marimo release
# that generated this file -- confirm before editing it by hand.
__generated_with = "0.9.17"
# Application object: its `cell` decorator is applied to the functions below
# and its `run()` method is called when the file is executed as a script.
app = marimo.App(width="medium")
@app.cell
def __():
    # Import the numeric and plotting libraries inside the cell and hand
    # them back under the names `np` and `plt`, which the other cells take
    # as parameters.
    import numpy as np
    import matplotlib.pyplot as plt

    return (np, plt)
@app.cell
def __(plt):
    # Global matplotlib text defaults for every figure drawn afterwards:
    # Times New Roman at 16 pt. Calls that pass their own `fontsize`
    # override the size locally.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return
@app.cell(hide_code=True)
def __(np, plt):
    # Draw a 2x3 grid of line charts comparing the LoRA operator without
    # graph pruning (black, "o" markers) against the operator with graph
    # pruning (red, "*" markers), then save it to "batchlora_op_cmp.pdf".
    #
    #   row 0 -> y-axis labelled "Time (us)"
    #   row 1 -> y-axis labelled "Peak Memory (GB)"
    #   columns -> panels tagged Model:1.1B, Model:7B, Model:13B
    #
    # In every panel the x-axis is the adapter count 1..N, where N is the
    # number of measurements recorded for that panel.
    #
    # `np` is not used here; it stays in the signature so the cell's
    # interface is unchanged.
    fig, ax = plt.subplots(
        figsize=(7, 3.5), ncols=3, nrows=2, layout="constrained"
    )

    # Palette. Only c_2 (baseline) and c_4 (pruned) are drawn. c_1 and c_3
    # are kept because they are part of the names this cell returns; c_3 was
    # the colour of a twin-axis overlay of the improvement percentage that
    # is no longer drawn.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # One entry per panel:
    # (row, col, model tag, y-axis label or None, y-limits, y-ticks,
    #  baseline series, pruned series)
    _panels = [
        (
            0, 0, "Model:1.1B", "Time (us)", (0, 170), [0, 50, 100, 150],
            [
                20.29159868,
                39.16359761,
                55.65218289,
                73.73416966,
                88.12034672,
                104.6562161,
                124.9680397,
                138.4435594,
            ],
            [
                24.05060625,
                36.4906544,
                49.58459261,
                62.96516142,
                75.19585842,
                86.59812198,
                104.7908515,
                116.5208559,
            ],
        ),
        (
            0, 1, "Model:7B", None, (0, 170), [0, 50, 100, 150],
            [21.9846773, 40.29510077, 59.41132316, 77.31547579, 95.81772611],
            [23.42831809, 37.3211666, 51.31030688, 65.47136931, 79.53268476],
        ),
        (
            0, 2, "Model:13B", None, (0, 85), [0, 25, 50, 75],
            [22.37562835, 42.73959994, 61.96534634],
            [22.80589938, 39.08431157, 54.03284915],
        ),
        (
            1, 0, "Model:1.1B", "Peak Memory (GB)", (0, 18), [0, 5, 10, 15],
            [
                3.501786362,
                5.201629498,
                6.926035673,
                8.653747242,
                10.38349666,
                12.11389446,
                13.84254159,
                15.57186355,
            ],
            [
                3.393746125,
                4.998805098,
                6.571097296,
                8.141429216,
                9.714332745,
                11.28814712,
                12.85960523,
                14.43241727,
            ],
        ),
        (
            1, 1, "Model:7B", None, (10, 27), [10, 15, 20, 25],
            [11.09702569, 14.20903317, 17.3902113, 20.59260555, 23.79299152],
            [10.8060544, 13.69216517, 16.51687164, 19.33916381, 22.15532633],
        ),
        (
            1, 2, "Model:13B", None, (15, 32), [15, 20, 25, 30],
            [18.52022991, 22.66454649, 26.87934514],
            [18.16663067, 22.02296134, 25.78340335],
        ),
    ]

    # The loop rebinds x / ticks / ticks_label / y_torch / y_batchlora /
    # y_imporve on every pass, so after the loop they hold the values of the
    # last panel (row 1, Model:13B) -- the same values the cell returned
    # before this was turned into a loop.
    for (
        _row,
        _col,
        _model,
        _ylabel,
        _ylim,
        _yticks,
        y_torch,
        y_batchlora,
    ) in _panels:
        _axes = ax[_row][_col]
        _is_first = (_row, _col) == (0, 0)

        x = list(range(1, len(y_torch) + 1))
        ticks = list(x)
        ticks_label = [str(t) for t in ticks]

        # Saving of the pruned operator relative to the baseline, in percent.
        # The misspelt name is kept because it is one of the returned names.
        y_imporve = [(t - b) / t * 100 for t, b in zip(y_torch, y_batchlora)]
        if _is_first:
            # Only the first panel's percentages are printed.
            print(y_imporve)

        # Legend labels are attached once (first panel) so the figure-level
        # legend lists each series a single time.
        _base_kw = {"label": "Operator without Graph Pruning"} if _is_first else {}
        _pruned_kw = {"label": "Operator with Graph Pruning"} if _is_first else {}
        _axes.plot(x, y_torch, color=c_2, marker="o", **_base_kw)
        _axes.plot(x, y_batchlora, color=c_4, marker="*", **_pruned_kw)

        _axes.set_ylim(*_ylim)
        _axes.set_xticks(ticks)
        _axes.set_xticklabels(ticks_label)
        _axes.set_yticks(_yticks)
        _axes.set_yticklabels(
            [str(_v) for _v in _yticks],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        if _ylabel is not None:
            _axes.set_ylabel(_ylabel)
        _axes.tick_params(pad=7)

        # Model tag in the lower-right corner, positioned in axes-fraction
        # coordinates via transAxes.
        _axes.text(
            0.95,
            0.02,
            _model,
            fontsize=14,
            va="bottom",
            ha="right",
            transform=_axes.transAxes,
        )

    # The shared x-label and the two sub-figure captions are attached to the
    # middle column.
    ax[1][1].set_xlabel("Number of simultaneously trained LoRA adapters")
    ax[0][1].set_title(
        "(a) The average time of forward and backward in the LoRA Operator",
        fontsize=16,
    )
    ax[1][1].set_title("(b) The peak memory used of LoRA Operator", fontsize=16)

    fig.legend(
        ncol=3,
        bbox_to_anchor=(1, 1.12),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    plt.savefig("batchlora_op_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_imporve,
        y_torch,
    )
# Run the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    app.run()