151 lines
4.4 KiB
Python
151 lines
4.4 KiB
Python
import marimo
|
|
|
|
# Version string of the marimo release that generated this notebook file.
__generated_with = "0.9.17"

# The notebook application object; every cell below is registered on it
# through the @app.cell decorator.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Import the numeric and plotting libraries once and hand them to the
    # other cells through this cell's return value.
    import numpy as np
    import matplotlib.pyplot as plt

    return np, plt
|
|
|
|
|
|
@app.cell
def __(plt):
    # Set the default font family and size used for text in every figure
    # created after this cell runs.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return
|
|
|
|
|
|
@app.cell
def __(np, plt):
    # Draw a three-panel line chart of training throughput (tokens/s) versus
    # the number of simultaneously trained LoRA adapters, one panel per model,
    # and save it as "lorapp-task.pdf" in the current working directory.

    # Not used by the plot itself; kept because the cell returns it.
    x = np.arange(4)

    fig, ax = plt.subplots(figsize=(7, 2), ncols=3, layout="constrained")

    # Line colours as RGB fractions, one per plotted method.
    c_1 = (139 / 255, 0 / 255, 0 / 255)  # dark red   -> "1F1B" (gpipe series)
    c_2 = (0, 0, 0)  # black      -> "TP"
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # light grey -> "FSDP"
    c_4 = (230 / 255, 109 / 255, 104 / 255)  # light red  -> "LoRAPP" (mLoRA series)

    # Panel (a): TinyLlama-1.1B.
    x_b1 = [1, 2, 3, 4]
    b1_throughput_mLoRA = [4600.45, 8664.91, 10118.36, 10184.44]
    b1_throughput_tp = [5752.14, 5749.34, 5756.78, 5758.32]
    b1_throughput_fsdp = [6151.91, 6141.73, 6161.23, 6153.93]
    b1_throughput_gpipe = [4599.87, 4610.19, 4592.17, 4601.18]

    # Panel (b): Llama-2-7B.
    x_b7 = [1, 2, 3, 4]
    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]

    # Panel (c): Llama-2-13B.
    x_b13 = [1, 2, 3, 4]
    b13_throughput_mLoRA = [723.21, 1280.54, 1286.54, 1282.54]
    b13_throughput_tp = [875, 877, 870, 878]
    # Placeholder values: the panel is annotated "FSDP : OOM" instead. The
    # zeros lie below the panel's lower y-limit (400), so this series is not
    # visible inside the axes.
    b13_throughput_fsdp = [0, 0, 0, 0]
    b13_throughput_gpipe = [723.21, 719.21, 726.21, 723.21]

    def _style_axis(axis, x_ticks, y_limits, y_ticks, y_tick_labels, title):
        # Apply the limits, ticks, tick labels and title shared by all panels.
        # The leading underscore keeps this helper local to the cell.
        axis.set_ylim(*y_limits)
        axis.set_xticks(x_ticks)
        axis.set_xticklabels([str(tick) for tick in x_ticks])
        axis.set_yticks(y_ticks)
        axis.set_yticklabels(
            y_tick_labels, rotation=90, ha="center", va="center", fontsize=12
        )
        axis.tick_params(pad=7)
        axis.set_title(title, fontsize=16)

    # --- (a) TinyLlama-1.1B -------------------------------------------------
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=16)
    ax[0].plot(x_b1, b1_throughput_fsdp, color=c_3, marker="v")
    ax[0].plot(x_b1, b1_throughput_tp, color=c_2, marker="o")
    ax[0].plot(x_b1, b1_throughput_gpipe, color=c_1, marker="^")
    ax[0].plot(x_b1, b1_throughput_mLoRA, color=c_4, marker="*")
    _style_axis(
        ax[0],
        x_b1,
        (3000, 12000),
        [5000, 7000, 9000, 11000],
        ["5k", "7k", "9k", "11k"],
        "(a) TinyLlama-1.1B",
    )

    # --- (b) Llama-2-7B -----------------------------------------------------
    # Only this panel passes label=..., so the figure legend is built from
    # these four lines, in this order.
    ax[1].plot(x_b7, b7_throughput_gpipe, color=c_1, label="1F1B", marker="^")
    ax[1].plot(x_b7, b7_throughput_tp, color=c_2, label="TP", marker="o")
    ax[1].plot(x_b7, b7_throughput_fsdp, color=c_3, label="FSDP", marker="v")
    ax[1].plot(x_b7, b7_throughput_mLoRA, color=c_4, label="LoRAPP", marker="*")
    _style_axis(
        ax[1],
        x_b7,
        (1000, 2500),
        [1200, 1700, 2200],
        ["1200", "1700", "2200"],
        "(b) Llama-2-7B",
    )
    ax[1].set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # --- (c) Llama-2-13B ----------------------------------------------------
    ax[2].plot(
        x_b13,
        b13_throughput_fsdp,
        color=c_3,
        marker="x",
        markerfacecolor="r",
        markeredgecolor="r",
    )
    ax[2].plot(x_b13, b13_throughput_tp, color=c_2, marker="o")
    ax[2].plot(x_b13, b13_throughput_gpipe, color=c_1, marker="^")
    ax[2].plot(x_b13, b13_throughput_mLoRA, color=c_4, marker="*")
    # Fix: the lowest tick sits at 600, so it is labelled "600" (it was
    # previously labelled "650", which misreported the axis scale).
    _style_axis(
        ax[2],
        x_b13,
        (400, 1500),
        [600, 950, 1300],
        ["600", "950", "1300"],
        "(c) Llama-2-13B",
    )
    # Annotation in axes coordinates (bottom-right corner of the panel).
    ax[2].text(
        0.95,
        0.05,
        "FSDP : OOM",
        fontsize=14,
        va="bottom",
        ha="right",
        transform=ax[2].transAxes,
        color="r",
        style="italic",
    )

    # Frameless four-column legend anchored above the top edge of the figure
    # (y = 1.2 in figure coordinates).
    fig.legend(
        ncol=4,
        bbox_to_anchor=(0.97, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    # Save this specific figure rather than relying on pyplot's notion of the
    # "current" figure; bbox_inches="tight" keeps the out-of-figure legend.
    fig.savefig("lorapp-task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        x,
        x_b1,
        x_b13,
        x_b7,
    )
|
|
|
|
|
|
# Run the notebook's cells when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
|