160 lines
3.9 KiB
Python
160 lines
3.9 KiB
Python
import marimo

# marimo file-format metadata: the marimo version that wrote this notebook.
__generated_with = "0.9.17"

# The notebook application object; every ``@app.cell`` function below is
# registered on it as one notebook cell.
app = marimo.App(width="medium")


@app.cell
def __():
    # Third-party modules shared with the rest of the notebook: the other
    # cells receive them as parameters (``def __(plt)``, ``def __(np, plt)``).
    import numpy as np
    import matplotlib.pyplot as plt

    return np, plt


@app.cell
def __(plt):
    # Global matplotlib font defaults used by the figures drawn afterwards.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return


@app.cell
def __(np, plt):
    # Grouped bar charts comparing the throughput of 1F1B, TP, FSDP and
    # LoRAPP on three models, one subplot per model. The figure is saved to
    # "pp_cmp_tp.pdf" relative to the current working directory.
    #
    # Helper names start with "_": marimo treats such names as private to
    # the cell, so they are neither exported nor able to clash with other
    # cells. (Comments are used instead of a docstring so that no bare
    # string statement is added to the cell's code.)

    # Not used below; kept because the cell exports it.
    x = np.arange(4)

    # Throughput (tokens/s) for TinyLlama-1.1B, Llama-2-7B and Llama-2-13B,
    # in that order. A 0 entry has no bar: FSDP on Llama-2-13B is annotated
    # "OOM" further down.
    lorapp = np.array([10748, 2363.89, 1280.54])
    tp = np.array([5752.14, 1500, 875])
    fsdp = np.array([6151.91, 1750, 0])
    gpipe = np.array([4599.87, 1284.27, 723.21])

    # The layout is computed once by fig.tight_layout() below. Do not also
    # pass layout="constrained": tight_layout() replaces that engine (and
    # matplotlib warns that the layout changed), so it would have no effect.
    fig, ax = plt.subplots(figsize=(7, 2), ncols=3)

    # Bar centres are space_width + bar_width apart, which leaves a gap of
    # space_width between neighbouring bars.
    space_width = 3 / 17
    bar_width = 3 * space_width

    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Centre of the k-th bar (k = 0..3) in every subplot: (k + 1) gaps plus
    # k bar widths.
    ticks = [(_k + 1) * space_width + _k * bar_width for _k in range(4)]
    # Method names in the same left-to-right order as the bars.
    ticks_label = ["1F1B", "TP", "FSDP", "LoRAPP"]

    # (per-model values, legend label, colour) in left-to-right bar order.
    # Every subplot draws from this single table, so the labels stay
    # consistent across subplots.
    _series = [
        (gpipe, "1F1B", c_1),
        (tp, "TP", c_2),
        (fsdp, "FSDP", c_3),
        (lorapp, "LoRAPP", c_4),
    ]
    # Per-subplot y axis: (limits, tick positions, tick labels, title).
    _panels = [
        (
            [0, 12000],
            [1000, 4000, 7000, 10000],
            ["1k", "4k", "7k", "10k"],
            "(a) TinyLlama-1.1B",
        ),
        (
            [0, 2500],
            [1000, 2000],
            ["1000", "2000"],
            "(b) Llama-2-7B",
        ),
        (
            [0, 1500],
            [400, 800, 1200],
            ["400", "800", "1200"],
            "(c) Llama-2-13B",
        ),
    ]

    for _i, (_axis, _panel) in enumerate(zip(ax, _panels)):
        _ylim, _yticks, _yticklabels, _title = _panel
        for _pos, (_values, _label, _color) in zip(ticks, _series):
            _axis.bar(_pos, _values[_i], bar_width, label=_label, color=_color)
        # Bars are identified through the legend, so the x axis is hidden.
        _axis.get_xaxis().set_visible(False)
        _axis.set_ylim(_ylim)
        _axis.set_yticks(_yticks)
        _axis.set_yticklabels(
            _yticklabels, rotation=90, ha="center", va="center"
        )
        _axis.tick_params(bottom=False, labelsize=13, pad=7)
        _axis.set_title(_title, fontsize=16)

    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=12)

    # FSDP has no bar for Llama-2-13B (value 0); label its slot instead.
    ax[2].text(
        ticks[2],
        100,
        "OOM",
        ha="center",
        va="center",
        color="r",
        style="italic",
        fontsize=14,
    )

    fig.tight_layout()
    # One legend for the whole figure. All subplots share the same labels
    # and colours, so ax[2]'s handles are representative. bbox_to_anchor is
    # in ax[2]'s axes coordinates and moves the legend to the left of and
    # above that subplot.
    ax[2].legend(
        ncol=4, loc="upper center", bbox_to_anchor=(-0.8, 1.50), frameon=False
    )

    fig.savefig("pp_cmp_tp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp,
        gpipe,
        lorapp,
        space_width,
        ticks,
        ticks_label,
        tp,
        x,
    )


if __name__ == "__main__":
    # Executing this file directly (``python <file>.py``) runs the notebook.
    app.run()