90 lines
2.0 KiB
Python
90 lines
2.0 KiB
Python
import marimo
|
||
|
||
# Presumably the marimo version that generated this notebook file; it looks
# tool-managed, so avoid editing it by hand (TODO confirm).
__generated_with = "0.9.10"

# Notebook application object. The `@app.cell` decorator below attaches the
# plotting cell to it; `width="medium"` presumably sets the notebook layout
# width (NOTE(review): confirm against marimo docs).
app = marimo.App(width="medium")
@app.cell
def __():
    # Line plot of average task completion time (hours) versus the number of
    # LoRA adapters trained at once, comparing "1F1B" with "mLoRA" for the
    # "(a) 70B A6000×4" setting named in the subplot title. The figure is
    # built but not saved (the savefig call at the end is commented out).
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    # Global typography: plain-text rendering (no LaTeX), Times New Roman, size 16.
    matplotlib.rcParams["text.usetex"] = False
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = 16

    # Colour palette as RGB tuples scaled to [0, 1]; c_2 is plain black.
    c_1 = tuple(channel / 255 for channel in (139, 0, 0))
    c_2 = (0, 0, 0)
    c_3 = tuple(channel / 255 for channel in (191, 191, 191))
    c_4 = tuple(channel / 255 for channel in (230, 109, 104))

    # Fixed workload size in tokens. The three factors are presumably
    # samples x sequence length x epochs -- TODO confirm with the experiment.
    total_token = 7473 * 512 * 10

    # Throughput for 1..4 adapters. Dividing tokens by these values and then
    # by 60 twice gives hours (the y-axis unit), so they are presumably
    # tokens per second.
    pp = [234.34, 234.32, 234.38, 234]
    mlora = [234, 291, 320, 318]

    def _hours(rate):
        # seconds -> minutes -> hours
        return total_token / rate / 60 / 60

    pp_avg_time = [_hours(rate) for rate in pp]
    mlora_avg_time = [_hours(rate) for rate in mlora]

    fig, ax = plt.subplots(figsize=(5.5, 2.5), constrained_layout=True)

    # Number of simultaneously trained adapters.
    x = [1, 2, 3, 4]

    # Baseline first, then mLoRA.
    ax.plot(x, pp_avg_time, color=c_1, marker="v", label="1F1B")
    ax.plot(x, mlora_avg_time, color=c_4, marker="*", label="mLoRA")

    # Tick labels are simply the tick values rendered as strings.
    ax.set_xticks(x)
    ax.set_xticklabels([str(count) for count in x])
    _y_ticks = [0, 40, 60]
    ax.set_yticks(_y_ticks)
    ax.set_yticklabels([str(tick) for tick in _y_ticks])
    # One unit of headroom above the highest tick.
    ax.set_ylim(0, _y_ticks[-1] + 1)

    ax.set_title("(a) 70B A6000×4", fontsize=16)
    ax.set_ylabel("Average task\ncompletion time (h)")
    ax.set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # Italic note in the lower-right corner (axes-fraction coordinates) saying
    # FSDP and TP went out of memory; neither is plotted.
    _oom_note_style = dict(
        fontsize=12,
        va="bottom",
        ha="right",
        transform=ax.transAxes,
        color=c_4,
        style="italic",
    )
    ax.text(0.95, 0.05, "FSDP : OOM\nTP : OOM", **_oom_note_style)

    # Figure-level legend: two columns, anchored at (0.75, 0.4), square
    # corners, fully transparent frame.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 0.4),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    # Uncomment to export the figure:
    # plt.savefig("end_to_end_large_model.pdf", bbox_inches="tight", dpi=1000, format="pdf")
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        matplotlib,
        mlora,
        mlora_avg_time,
        np,
        plt,
        pp,
        pp_avg_time,
        total_token,
        x,
    )
if __name__ == "__main__":
    # Run the notebook app when this file is executed directly as a script.
    app.run()