313 lines
9.0 KiB
Python
313 lines
9.0 KiB
Python
import marimo
|
|
|
|
# Version string recorded in this notebook file for the marimo tooling.
__generated_with = "0.9.17"

# The notebook application; each cell below registers itself on it through
# the @app.cell decorator.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Import the plotting and numeric libraries once and hand them out under
    # their usual aliases; the other cells take `np` / `plt` as parameters.
    from matplotlib import pyplot as plt
    import numpy as np

    return np, plt
|
|
|
|
|
|
@app.cell
def __(plt):
    # Global matplotlib text defaults for every figure drawn afterwards:
    # Times New Roman at 16 pt. Individual calls may still override the size.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return
|
|
|
|
|
|
@app.cell(hide_code=True)
def __(np, plt):
    # Draws a 2x3 grid of line charts comparing two measurement series per
    # panel ("Operator without Graph Pruning" vs. "Operator with Graph
    # Pruning") and saves the figure as "batchlora_op_cmp.pdf".
    #   * top row    -> y axis "Time (us)"
    #   * bottom row -> y axis "Peak Memory (GB)"
    #   * columns    -> model tags 1.1B, 7B, 13B
    # The x axis of every panel is simply 1..N for the N values in the series.
    #
    # `np` is accepted for interface compatibility but is not used here.
    #
    # The names x, ticks, ticks_label, y_torch, y_batchlora and y_imporve are
    # rebound for every panel, so the values returned at the end are those of
    # the last panel (bottom row, 13B).

    # Colour palette. Only c_2 and c_4 are drawn in this cell; c_1 and c_3
    # are defined and returned without being used for any artist.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    fig, ax = plt.subplots(
        figsize=(7, 3.5), ncols=3, nrows=2, layout="constrained"
    )

    def _draw_panel(
        axis,
        baseline,
        pruned,
        y_limits,
        y_ticks,
        model_tag,
        y_label=None,
        legend_labels=None,
    ):
        # Render one comparison panel on `axis`.
        #
        # baseline / pruned: equally long lists of y values.
        # y_limits: (low, high) pair for the y axis.
        # y_ticks: y tick positions; their labels are the same numbers as text.
        # model_tag: text placed in the lower-right corner of the panel.
        # y_label: optional y-axis label.
        # legend_labels: optional (baseline label, pruned label) pair.
        #
        # Returns (x positions, x tick positions, x tick labels, improvement),
        # where improvement[i] = (baseline[i] - pruned[i]) / baseline[i] * 100.
        positions = list(range(1, len(baseline) + 1))
        tick_values = list(positions)
        tick_text = [str(p) for p in tick_values]
        improvement = [(t - b) / t * 100 for t, b in zip(baseline, pruned)]

        # Pass `label` only when one is given, so unlabeled panels call
        # plot() exactly as they would without the keyword.
        baseline_kw = {}
        pruned_kw = {}
        if legend_labels is not None:
            baseline_kw["label"], pruned_kw["label"] = legend_labels

        axis.plot(positions, baseline, color=c_2, marker="o", **baseline_kw)
        axis.plot(positions, pruned, color=c_4, marker="*", **pruned_kw)
        axis.set_ylim(*y_limits)
        axis.set_xticks(tick_values)
        axis.set_xticklabels(tick_text)
        axis.set_yticks(y_ticks)
        axis.set_yticklabels(
            [str(t) for t in y_ticks],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        if y_label is not None:
            axis.set_ylabel(y_label)
        axis.tick_params(pad=7)
        # Axes-fraction coordinates: lower-right corner of the panel.
        axis.text(
            0.95,
            0.02,
            model_tag,
            fontsize=14,
            va="bottom",
            ha="right",
            transform=axis.transAxes,
        )
        return positions, tick_values, tick_text, improvement

    # ---- top row: time ----------------------------------------------------
    # 1.1B panel; the only one whose lines carry legend labels, so the shared
    # figure legend shows each series once.
    y_torch = [
        20.29159868,
        39.16359761,
        55.65218289,
        73.73416966,
        88.12034672,
        104.6562161,
        124.9680397,
        138.4435594,
    ]
    y_batchlora = [
        24.05060625,
        36.4906544,
        49.58459261,
        62.96516142,
        75.19585842,
        86.59812198,
        104.7908515,
        116.5208559,
    ]
    # NOTE: "y_imporve" (sic) is one of the names this cell returns, so the
    # spelling is kept.
    x, ticks, ticks_label, y_imporve = _draw_panel(
        ax[0][0],
        y_torch,
        y_batchlora,
        (0, 170),
        [0, 50, 100, 150],
        "Model:1.1B",
        y_label="Time (us)",
        legend_labels=(
            "Operator without Graph Pruning",
            "Operator with Graph Pruning",
        ),
    )
    # Console output: percentage improvement of this first panel only.
    print(y_imporve)

    y_torch = [21.9846773, 40.29510077, 59.41132316, 77.31547579, 95.81772611]
    y_batchlora = [23.42831809, 37.3211666, 51.31030688, 65.47136931, 79.53268476]
    x, ticks, ticks_label, y_imporve = _draw_panel(
        ax[0][1], y_torch, y_batchlora, (0, 170), [0, 50, 100, 150], "Model:7B"
    )

    y_torch = [22.37562835, 42.73959994, 61.96534634]
    y_batchlora = [22.80589938, 39.08431157, 54.03284915]
    x, ticks, ticks_label, y_imporve = _draw_panel(
        ax[0][2], y_torch, y_batchlora, (0, 85), [0, 25, 50, 75], "Model:13B"
    )

    # ---- bottom row: peak memory -------------------------------------------
    y_torch = [
        3.501786362,
        5.201629498,
        6.926035673,
        8.653747242,
        10.38349666,
        12.11389446,
        13.84254159,
        15.57186355,
    ]
    y_batchlora = [
        3.393746125,
        4.998805098,
        6.571097296,
        8.141429216,
        9.714332745,
        11.28814712,
        12.85960523,
        14.43241727,
    ]
    x, ticks, ticks_label, y_imporve = _draw_panel(
        ax[1][0],
        y_torch,
        y_batchlora,
        (0, 18),
        [0, 5, 10, 15],
        "Model:1.1B",
        y_label="Peak Memory (GB)",
    )

    y_torch = [11.09702569, 14.20903317, 17.3902113, 20.59260555, 23.79299152]
    y_batchlora = [10.8060544, 13.69216517, 16.51687164, 19.33916381, 22.15532633]
    x, ticks, ticks_label, y_imporve = _draw_panel(
        ax[1][1], y_torch, y_batchlora, (10, 27), [10, 15, 20, 25], "Model:7B"
    )

    y_torch = [18.52022991, 22.66454649, 26.87934514]
    y_batchlora = [18.16663067, 22.02296134, 25.78340335]
    x, ticks, ticks_label, y_imporve = _draw_panel(
        ax[1][2], y_torch, y_batchlora, (15, 32), [15, 20, 25, 30], "Model:13B"
    )

    # ---- figure-level decoration -------------------------------------------
    # The x label and both row captions are attached to the middle column so
    # they appear centred over/under each row.
    ax[1][1].set_xlabel("Number of simultaneously trained LoRA adapters")
    ax[0][1].set_title(
        "(a) The average time of forward and backward in the LoRA Operator",
        fontsize=16,
    )
    ax[1][1].set_title("(b) The peak memory used of LoRA Operator", fontsize=16)

    # Frameless legend anchored above the top edge of the figure (y = 1.12 in
    # figure coordinates).
    fig.legend(
        ncol=3,
        bbox_to_anchor=(1, 1.12),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    # bbox_inches="tight" keeps the legend, which sits outside the figure
    # area, inside the saved page.
    plt.savefig("batchlora_op_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_imporve,
        y_torch,
    )
|
|
|
|
|
|
# Script entry point: execute the notebook's cells when the file is run
# directly rather than imported.
if __name__ == "__main__":
    app.run()
|