paper_note/mlora/pp_cmp_tp.py
2025-03-05 20:38:41 +08:00

160 lines
3.9 KiB
Python

import marimo
# marimo version recorded when this notebook file was generated.
__generated_with = "0.9.17"
# The marimo notebook object; each `@app.cell`-decorated function in this
# file is registered on it as one cell.
app = marimo.App(width="medium")
@app.cell
def __():
    # Third-party modules shared with the other cells, which receive them
    # as the `np` / `plt` parameters.
    import numpy as np
    from matplotlib import pyplot as plt
    return (np, plt)
@app.cell
def __(plt):
    # Global matplotlib font defaults used by figures created after this cell.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return
@app.cell
def __(np, plt):
    # Throughput (tokens/s) of four parallelism strategies on three models:
    # one bar panel per model, written to pp_cmp_tp.pdf.
    # Names prefixed with "_" are cell-local in marimo; the rest are exported
    # through the return tuple below.
    x = np.arange(4)  # unused in this cell; kept because the cell exports it

    # Measured throughput, one entry per panel:
    # [TinyLlama-1.1B, Llama-2-7B, Llama-2-13B].
    lorapp = np.array([10748, 2363.89, 1280.54])
    tp = np.array([5752.14, 1500, 875])
    fsdp = np.array([6151.91, 1750, 0])  # 13B entry is 0; shown as "OOM" below
    gpipe = np.array([4599.87, 1284.27, 723.21])

    # No `layout="constrained"` here: the tight_layout() call further down
    # replaces any constrained layout engine (matplotlib warns "The figure
    # layout has changed to tight"), so only the tight layout ever applied.
    fig, ax = plt.subplots(figsize=(7, 2), ncols=3)

    # Bars are bar_width wide and separated by gaps of space_width.
    space_width = 3 / 17
    bar_width = 3 * space_width
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)
    # Bar centre positions: the k-th bar sits after k+1 gaps and k bars.
    ticks = [(_k + 1) * space_width + _k * bar_width for _k in range(4)]
    # Strategy labels in drawing order; must stay aligned with `ticks`,
    # `_heights` and `_colors` (the old list had TP/FSDP swapped).
    ticks_label = ["1F1B", "TP", "FSDP", "LoRAPP"]
    _heights = [gpipe, tp, fsdp, lorapp]
    _colors = [c_1, c_2, c_3, c_4]

    # Per-panel title, y-axis upper limit, y tick positions and tick labels.
    _panels = [
        (
            "(a) TinyLlama-1.1B",
            12000,
            [1000, 4000, 7000, 10000],
            ["1k", "4k", "7k", "10k"],
        ),
        ("(b) Llama-2-7B", 2500, [1000, 2000], ["1000", "2000"]),
        ("(c) Llama-2-13B", 1500, [400, 800, 1200], ["400", "800", "1200"]),
    ]
    for _i, (_title, _ymax, _yticks, _yticklabels) in enumerate(_panels):
        _ax = ax[_i]
        # Same strategy -> label/colour mapping in every panel (panel (b)
        # previously had its 1F1B and LoRAPP labels swapped).
        for _pos, _label, _values, _color in zip(
            ticks, ticks_label, _heights, _colors
        ):
            _ax.bar(_pos, _values[_i], bar_width, label=_label, color=_color)
        # Hide the x axis: strategies are identified by the legend instead.
        _ax.get_xaxis().set_visible(False)
        _ax.set_ylim([0, _ymax])
        _ax.set_yticks(_yticks)
        _ax.set_yticklabels(
            _yticklabels, rotation=90, ha="center", va="center"
        )
        _ax.tick_params(bottom=False, labelsize=13, pad=7)
        _ax.set_title(_title, fontsize=16)
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=12)

    # FSDP has no 13B result (height 0); label its empty slot in panel (c).
    ax[2].text(
        ticks[2],
        100,
        "OOM",
        ha="center",
        va="center",
        color="r",
        style="italic",
        fontsize=14,
    )

    fig.tight_layout()
    # The legend takes its handles from ax[2] and is anchored in ax[2]'s axes
    # coordinates; (-0.8, 1.50) moves it left and above the panels
    # (presumably to centre it over the whole row — re-tune if figsize or
    # the number of panels changes).
    ax[2].legend(
        ncol=4, loc="upper center", bbox_to_anchor=(-0.8, 1.50), frameon=False
    )
    fig.savefig("pp_cmp_tp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp,
        gpipe,
        lorapp,
        space_width,
        ticks,
        ticks_label,
        tp,
        x,
    )
# Script entry point: when this file is executed directly, hand control to
# marimo's App.run().
if __name__ == "__main__":
    app.run()