paper_note/mlora/lorapp_task.py
2025-03-05 20:38:41 +08:00

151 lines
4.4 KiB
Python

import marimo

# marimo version this notebook file was generated with.
__generated_with = "0.9.17"
# The notebook application object; every `@app.cell`-decorated function in
# this file is registered on it as one cell.
app = marimo.App(width="medium")
@app.cell
def __():
    # Shared third-party dependencies; exported so the plotting cells below
    # can receive them as parameters.
    import numpy as np
    import matplotlib.pyplot as plt

    return np, plt
@app.cell
def __(plt):
    # Global typography applied to every figure drawn after this cell runs:
    # Times New Roman at 16 pt.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})
    return
@app.cell
def __(np, plt):
    # Three-panel figure: throughput (tokens/s) versus the number of
    # simultaneously trained LoRA adapters (1-4) for TinyLlama-1.1B,
    # Llama-2-7B and Llama-2-13B, comparing 1F1B, TP, FSDP and LoRAPP.
    # The result is written to "lorapp-task.pdf".
    #
    # NOTE(review): `x` is not used by this cell; it is kept only because it
    # is one of the names this cell exports (see the return tuple).
    x = np.arange(4)
    fig, ax = plt.subplots(figsize=(7, 2), ncols=3, layout="constrained")

    # One colour per strategy (RGB in [0, 1]); the mapping follows the
    # labels attached in panel (b).
    c_1 = (139 / 255, 0 / 255, 0 / 255)  # 1F1B (`*_gpipe` series)
    c_2 = (0, 0, 0)  # TP
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # FSDP
    c_4 = (230 / 255, 109 / 255, 104 / 255)  # LoRAPP (`*_mLoRA` series)

    # (a) TinyLlama-1.1B measurements, one value per adapter count in x_b1.
    x_b1 = [1, 2, 3, 4]
    b1_throughput_mLoRA = [4600.45, 8664.91, 10118.36, 10184.44]
    b1_throughput_tp = [5752.14, 5749.34, 5756.78, 5758.32]
    b1_throughput_fsdp = [6151.91, 6141.73, 6161.23, 6153.93]
    b1_throughput_gpipe = [4599.87, 4610.19, 4592.17, 4601.18]

    # (b) Llama-2-7B measurements.
    x_b7 = [1, 2, 3, 4]
    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]

    # (c) Llama-2-13B measurements. FSDP hit OOM for every adapter count, so
    # its series is a placeholder of zeros.
    x_b13 = [1, 2, 3, 4]
    b13_throughput_mLoRA = [723.21, 1280.54, 1286.54, 1282.54]
    b13_throughput_tp = [875, 877, 870, 878]
    b13_throughput_fsdp = [0, 0, 0, 0]  # for the error (OOM)
    b13_throughput_gpipe = [723.21, 719.21, 726.21, 723.21]

    # Panel (a). Lines here are unlabelled on purpose: fig.legend() below
    # collects labels from all axes, and panel (b) already supplies them.
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=16)
    ax[0].plot(x_b1, b1_throughput_fsdp, color=c_3, marker="v")
    ax[0].plot(x_b1, b1_throughput_tp, color=c_2, marker="o")
    ax[0].plot(x_b1, b1_throughput_gpipe, color=c_1, marker="^")
    ax[0].plot(x_b1, b1_throughput_mLoRA, color=c_4, marker="*")
    ax[0].set_ylim(3000, 12000)
    ax[0].set_xticks([1, 2, 3, 4])
    ax[0].set_xticklabels(["1", "2", "3", "4"])
    ax[0].set_yticks([5000, 7000, 9000, 11000])
    ax[0].set_yticklabels(
        ["5k", "7k", "9k", "11k"], rotation=90, ha="center", va="center", fontsize=12
    )
    ax[0].tick_params(pad=7)
    ax[0].set_title("(a) TinyLlama-1.1B", fontsize=16)

    # Panel (b). The only panel whose lines carry legend labels; the label
    # order here is the legend order.
    ax[1].plot(x_b7, b7_throughput_gpipe, color=c_1, label="1F1B", marker="^")
    ax[1].plot(x_b7, b7_throughput_tp, color=c_2, label="TP", marker="o")
    ax[1].plot(x_b7, b7_throughput_fsdp, color=c_3, label="FSDP", marker="v")
    ax[1].plot(x_b7, b7_throughput_mLoRA, color=c_4, label="LoRAPP", marker="*")
    ax[1].set_ylim(1000, 2500)
    ax[1].set_xticks([1, 2, 3, 4])
    ax[1].set_xticklabels(["1", "2", "3", "4"])
    ax[1].set_yticks([1200, 1700, 2200])
    ax[1].set_yticklabels(
        ["1200", "1700", "2200"], rotation=90, ha="center", va="center", fontsize=12
    )
    ax[1].tick_params(pad=7)
    ax[1].set_title("(b) Llama-2-7B", fontsize=16)
    ax[1].set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # Panel (c). The all-zero FSDP series lies below the y-limit (400), so
    # its red "x" markers are clipped; the OOM is conveyed by the text
    # annotation further down instead.
    ax[2].plot(
        x_b13,
        b13_throughput_fsdp,
        color=c_3,
        marker="x",
        markerfacecolor="r",
        markeredgecolor="r",
    )
    ax[2].plot(x_b13, b13_throughput_tp, color=c_2, marker="o")
    ax[2].plot(x_b13, b13_throughput_gpipe, color=c_1, marker="^")
    ax[2].plot(x_b13, b13_throughput_mLoRA, color=c_4, marker="*")
    ax[2].set_ylim(400, 1500)
    ax[2].set_xticks([1, 2, 3, 4])
    ax[2].set_xticklabels(["1", "2", "3", "4"])
    # Labels are derived from the tick positions so each tick is labelled
    # with the value it actually marks.
    _b13_yticks = [600, 950, 1300]
    ax[2].set_yticks(_b13_yticks)
    ax[2].set_yticklabels(
        [str(_tick) for _tick in _b13_yticks],
        rotation=90,
        ha="center",
        va="center",
        fontsize=12,
    )
    ax[2].tick_params(pad=7)
    ax[2].set_title("(c) Llama-2-13B", fontsize=16)
    # Annotation anchored in axes coordinates (bottom-right corner).
    ax[2].text(
        0.95,
        0.05,
        "FSDP : OOM",
        fontsize=14,
        va="bottom",
        ha="right",
        transform=ax[2].transAxes,
        color="r",
        style="italic",
    )

    # Single shared legend above the panels, built from the labelled lines.
    fig.legend(
        ncol=4,
        bbox_to_anchor=(0.97, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    # Save this figure explicitly rather than whatever pyplot's current
    # figure happens to be.
    fig.savefig("lorapp-task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        x,
        x_b1,
        x_b13,
        x_b7,
    )
# Entry point when this file is executed directly rather than opened in the
# marimo editor: run the notebook's cells via the app object.
if __name__ == "__main__":
    app.run()