134 lines
3.9 KiB
Python
134 lines
3.9 KiB
Python
import marimo
|
|
|
|
# Version string recorded by marimo for the tool that generated this file.
__generated_with = "0.7.0"
# The notebook object; every @app.cell-decorated function below registers
# one cell with it. width="medium" selects the notebook's page width.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Shared imports for the notebook. Returning them makes them available
    # to other cells, which receive them as parameters (e.g. `plt`).
    # NOTE(review): `np` and `random` are not used by any visible cell;
    # they are kept because they are part of this cell's exported names.
    import random

    import matplotlib.pyplot as plt
    import numpy as np

    return np, plt, random
|
|
|
|
|
|
@app.cell
def __(plt):
    # Notebook-wide Matplotlib font defaults: Times New Roman at 16 pt.
    # NOTE(review): this cell defines no variables, so no dataflow edge
    # forces it to run before the plotting cell — confirm the fonts are
    # actually applied to the figure.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return
|
|
|
|
|
|
@app.cell(hide_code=True)
def __(plt):
    # Three-panel bar chart comparing PEFT (black bars) with BatchLoRA
    # (salmon bars) for three model sizes (1.1B, 7B, 13B):
    #   (a) training time, (b) proportion of kernel launch time,
    #   (c) kernel execution time.
    fig, ax = plt.subplots(figsize=(7, 2.4), ncols=3, layout="constrained")

    # Layout: one adjacent PEFT/BatchLoRA bar pair per model size, with
    # pairs separated by `space_width`.
    space_width = 3 / 22
    bar_width = 3 * space_width

    c_1 = (139 / 255, 0 / 255, 0 / 255)  # not drawn here; exported only
    c_2 = (0, 0, 0)  # PEFT bars
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # not drawn here; exported only
    c_4 = (230 / 255, 109 / 255, 104 / 255)  # BatchLoRA bars

    # Bar centres of pair k (Axes.bar centres each bar on x by default).
    _left = [(k + 1) * space_width + 2 * k * bar_width for k in range(3)]
    _right = [(k + 1) * space_width + (2 * k + 1) * bar_width for k in range(3)]

    # Each tick sits at the midpoint between the two bars of its pair so the
    # centred tick label lines up under both bars.
    x_ticks = [
        (k + 1) * space_width + (2 * k + 0.5) * bar_width for k in range(3)
    ]
    x_ticks_label = ["1.1B", "7B", "13B"]

    # Panel (a): 512 * 8 divided by a measured value per configuration.
    # NOTE(review): presumably 512 * 8 tokens per step over a throughput in
    # tokens/s, giving seconds — confirm against the benchmark logs.
    _train_peft = [
        512 * 8 / 2812,
        512 * 8 / 690.5880666,
        512 * 8 / 396.4798927,
    ]
    _train_batchlora = [
        512 * 8 / 3054,
        512 * 8 / 727.7364507,
        512 * 8 / 405.9223082,
    ]

    # Panel (b): kernel launch time as a percentage.
    _launch_peft = [10.1, 7.5, 3.9]
    _launch_batchlora = [2.4, 2.1, 1.1]

    # Raw kernel-time counters; not plotted, only exported from the cell.
    # NOTE(review): units are not visible here (possibly ns) — confirm.
    peft_kern_time = [12969965266, 56659466215, 96280798449]
    batchlora_kern_time = [12969965266, 55606549735, 96080798449]

    # Panel (c): kernel execution time (seconds, per the axis label).
    pt = [1.3109530583214795, 5.493869969292716, 9.928009143652595]
    bt = [1.309003274394237, 5.493329344922422, 9.919060664229837]

    # (axis, PEFT values, BatchLoRA values, y label, title, title pad,
    #  whether this panel supplies the legend entries)
    _panels = [
        (
            ax[0], _train_peft, _train_batchlora,
            "Time (s)", "(a) Training time", 22, False,
        ),
        (
            ax[1], _launch_peft, _launch_batchlora,
            "Percentage (%)", "(b) Proportion of \nkernel launch time",
            None, False,
        ),
        (
            ax[2], pt, bt,
            "Time (s)", "(c) Kernel execution time", 22, True,
        ),
    ]

    for _axis, _peft, _batchlora, _ylabel, _title, _pad, _legend in _panels:
        for _k in range(3):
            # Label only the first pair of the legend panel so fig.legend()
            # collects exactly one "PEFT" and one "BatchLoRA" entry; "" is
            # Axes.bar's default label and is ignored by the legend.
            _first = _legend and _k == 0
            _axis.bar(
                _left[_k], _peft[_k], bar_width, color=c_2,
                label="PEFT" if _first else "",
            )
            _axis.bar(
                _right[_k], _batchlora[_k], bar_width, color=c_4,
                label="BatchLoRA" if _first else "",
            )
        _axis.set_xticks(x_ticks)
        _axis.set_xticklabels(x_ticks_label, ha="center", va="center")
        _axis.tick_params(bottom=False, labelsize=14, pad=7)
        _axis.set_ylabel(_ylabel, fontsize=16)
        # pad=None keeps Matplotlib's default title padding, which panel (b)
        # uses for its two-line title.
        _axis.set_title(_title, fontsize=16, pad=_pad)

    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.8, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    # plt.savefig("batchlora_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        batchlora_kern_time,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        peft_kern_time,
        pt,
        space_width,
        x_ticks,
        x_ticks_label,
    )
|
|
|
|
|
|
# When this file is executed directly (not imported), run the notebook via
# app.run().
if __name__ == "__main__":
    app.run()
|