paper_note/mlora/end_to_end_large_model.py
2025-03-05 20:38:41 +08:00

90 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import marimo

# Version string of marimo recorded in this notebook file (the name suggests
# marimo's code generator wrote it -- leave as-is).
__generated_with = "0.9.10"
# Notebook application object; the @app.cell-decorated function below is
# registered on it, and the __main__ guard at the bottom calls app.run().
app = marimo.App(width="medium")
@app.cell
def __():
    # Draws the 70B / A6000x4 end-to-end comparison: the hours each task needs
    # under the "1F1B" baseline versus mLoRA, as a function of how many LoRA
    # adapters are trained simultaneously (1..4).
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    # Global figure styling: no LaTeX text rendering, Times New Roman, 16 pt.
    matplotlib.rcParams.update(
        {
            "text.usetex": False,
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )

    def _rgb255(red, green, blue):
        # Scale 0-255 channel values to the 0-1 floats matplotlib expects.
        return (red / 255, green / 255, blue / 255)

    # Palette. In this cell c_1 colours the "1F1B" line and c_4 colours the
    # "mLoRA" line plus the OOM note; c_2 and c_3 are only exported.
    c_1 = _rgb255(139, 0, 0)
    c_2 = (0, 0, 0)
    c_3 = _rgb255(191, 191, 191)
    c_4 = _rgb255(230, 109, 104)

    # Tokens each task has to process.
    # NOTE(review): the factors look like samples x sequence length x epochs;
    # confirm against the experiment setup.
    total_token = 7473 * 512 * 10

    # Measured throughput for 1..4 simultaneously trained adapters. The
    # /60/60 conversion to hours below implies tokens per second.
    pp = [234.34, 234.32, 234.38, 234]
    mlora = [234, 291, 320, 318]

    def _to_hours(throughput):
        # Seconds needed for total_token, then seconds -> minutes -> hours.
        return total_token / throughput / 60 / 60

    pp_avg_time = list(map(_to_hours, pp))
    mlora_avg_time = list(map(_to_hours, mlora))

    fig, ax = plt.subplots(figsize=(5.5, 2.5), constrained_layout=True)
    x = [1, 2, 3, 4]  # number of LoRA adapters trained at once

    # One line per method; drawing order fixes the legend order.
    for _times, _color, _marker, _label in (
        (pp_avg_time, c_1, "v", "1F1B"),
        (mlora_avg_time, c_4, "*", "mLoRA"),
    ):
        ax.plot(x, _times, color=_color, marker=_marker, label=_label)

    ax.set_xticks(x)
    ax.set_xticklabels([str(_n) for _n in x])
    _y_ticks = [0, 40, 60]
    ax.set_yticks(_y_ticks)
    ax.set_yticklabels([str(_t) for _t in _y_ticks])
    # Leave one hour of headroom above the top tick.
    ax.set_ylim(0, _y_ticks[-1] + 1)

    ax.set_title("(a) 70B A6000×4", fontsize=16)
    ax.set_ylabel("Average task\ncompletion time (h)")
    ax.set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # Note anchored at the lower-right corner, in axes-fraction coordinates.
    ax.text(
        0.95,
        0.05,
        "FSDP : OOM\nTP : OOM",
        fontsize=12,
        va="bottom",
        ha="right",
        transform=ax.transAxes,
        color=c_4,
        style="italic",
    )
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 0.4),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )
    # Uncomment to export the finished figure (legend included):
    # plt.savefig("end_to_end_large_model.pdf", bbox_inches="tight", dpi=1000, format="pdf")
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        matplotlib,
        mlora,
        mlora_avg_time,
        np,
        plt,
        pp,
        pp_avg_time,
        total_token,
        x,
    )
# Script entry point: when this file is executed directly, hand control to
# the marimo App object's run() method.
if __name__ == "__main__":
    app.run()