# paper_note/mlora/batchlora_cmp.py
# 2025-03-05 20:38:41 +08:00
#
# 134 lines
# 3.9 KiB
# Python

import marimo
# Version stamp stored with the notebook (presumably the marimo release that
# generated this file -- confirm against marimo's file-format docs).
__generated_with = "0.7.0"
# The notebook application object: every cell below registers itself on it
# through the @app.cell decorator, and the script entry point calls app.run().
app = marimo.App(width="medium")
@app.cell
def __():
    # Import cell: the modules are handed to the rest of the notebook through
    # the returned tuple (np, plt, random).
    import random

    import numpy as np
    import matplotlib.pyplot as plt

    return np, plt, random
@app.cell
def __(plt):
    # Figure-wide text defaults: Times New Roman at 16 pt. Set through the
    # shared rcParams mapping, so every figure created afterwards inherits them.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return
@app.cell(hide_code=True)
def __(plt):
    # Builds a 1x3 figure of grouped bar charts comparing PEFT (c_2 bars) with
    # BatchLoRA (c_4 bars) for three model sizes (1.1B / 7B / 13B):
    #   (a) training time, (b) proportion of kernel launch time,
    #   (c) kernel execution time.
    # Helper names carry a leading underscore so they stay out of the tuple of
    # definitions returned at the end of the cell.
    fig, ax = plt.subplots(figsize=(7, 2.4), ncols=3, layout="constrained")

    # Horizontal layout unit: one bar is three gap-units wide.
    space_width = 3 / 22
    bar_width = 3 * space_width

    # RGB colours in [0, 1]. Only c_2 and c_4 are drawn with in this cell;
    # c_1 and c_3 are simply handed back to the notebook.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # One tick per model size.
    # NOTE(review): with centre-aligned bars these positions sit
    # space_width / 2 to the right of each pair's midpoint -- confirm that
    # this offset is intended.
    x_ticks = [
        bar_width,
        space_width + 3 * bar_width,
        2 * space_width + 5 * bar_width,
    ]
    x_ticks_label = ["1.1B", "7B", "13B"]

    def _grouped_bars(_axis, _peft, _batchlora, _legend=False):
        # Draw one (PEFT, BatchLoRA) bar pair per model size. Group g starts
        # (g + 1) gaps plus 2 * g bar widths from the origin. Only the first
        # pair is labelled (and only when _legend is set), so the figure
        # legend ends up with exactly two entries.
        for _g, (_p, _b) in enumerate(zip(_peft, _batchlora)):
            _offset = (_g + 1) * space_width
            _tagged = _legend and _g == 0
            _axis.bar(
                _offset + 2 * _g * bar_width,
                _p,
                bar_width,
                color=c_2,
                **({"label": "PEFT"} if _tagged else {}),
            )
            _axis.bar(
                _offset + (2 * _g + 1) * bar_width,
                _b,
                bar_width,
                color=c_4,
                **({"label": "BatchLoRA"} if _tagged else {}),
            )

    def _decorate(_axis, _ylabel, _title, **_title_opts):
        # Shared tick / label / title styling for one panel; extra keyword
        # arguments (e.g. pad) are forwarded to set_title unchanged.
        _axis.set_xticks(x_ticks)
        _axis.set_xticklabels(x_ticks_label, ha="center", va="center")
        _axis.tick_params(bottom=False, labelsize=14, pad=7)
        _axis.set_ylabel(_ylabel, fontsize=16)
        _axis.set_title(_title, fontsize=16, **_title_opts)

    # (a) Bar height is 512 * 8 divided by a measured value per model size
    # (presumably tokens per step over tokens/s throughput -- TODO confirm).
    _work = 512 * 8
    _peft_measured = [2812, 690.5880666, 396.4798927]
    _batchlora_measured = [3054, 727.7364507, 405.9223082]
    _grouped_bars(
        ax[0],
        [_work / _m for _m in _peft_measured],
        [_work / _m for _m in _batchlora_measured],
    )
    _decorate(ax[0], "Time (s)", "(a) Training time", pad=22)

    # (b) Percentages are plotted as given; the two-line title gets no pad.
    _grouped_bars(ax[1], [10.1, 7.5, 3.9], [2.4, 2.1, 1.1])
    _decorate(ax[1], "Percentage (%)", "(b) Proportion of \nkernel launch time")

    # (c) The *_kern_time lists are not plotted here; they are only returned.
    # pt / bt hold the plotted values for PEFT / BatchLoRA respectively.
    peft_kern_time = [12969965266, 56659466215, 96280798449]
    batchlora_kern_time = [12969965266, 55606549735, 96080798449]
    pt = [1.3109530583214795, 5.493869969292716, 9.928009143652595]
    bt = [1.309003274394237, 5.493329344922422, 9.919060664229837]
    _grouped_bars(ax[2], pt, bt, _legend=True)
    _decorate(ax[2], "Time (s)", "(c) Kernel execution time", pad=22)

    # Single frameless legend for the whole figure, anchored above the panels.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.8, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    # plt.savefig("batchlora_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        batchlora_kern_time,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        peft_kern_time,
        pt,
        space_width,
        x_ticks,
        x_ticks_label,
    )
# Script entry point: run the notebook app when this file is executed directly
# rather than imported.
if __name__ == "__main__":
    app.run()