import marimo

__generated_with = "0.9.10"
app = marimo.App(width="medium")


@app.cell
def __():
    # Plot: average task completion time of 1F1B vs. mLoRA for the
    # "(a) 70B A6000x4" setting, as the number of simultaneously trained
    # LoRA adapters goes from 1 to 4.
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    matplotlib.rcParams["text.usetex"] = False
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = 16

    # Shared colour palette as RGB tuples in [0, 1].
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Token count that each throughput figure is divided into.
    # NOTE(review): the factors look like samples x sequence length x epochs
    # — confirm against the experiment setup.
    total_token = 7473 * 512 * 10

    # Measured throughput for 1, 2, 3 and 4 adapters.
    # NOTE(review): assumed to be tokens/second, since the conversion below
    # divides by 60 twice to produce hours — confirm the unit.
    pp = [234.34, 234.32, 234.38, 234]
    mlora = [234, 291, 320, 318]

    def _to_hours(throughputs):
        # Cell-private helper (underscore prefix keeps it out of marimo's
        # shared namespace). Keeps the original `/ 60 / 60` evaluation order
        # so the results are bit-identical to the previous inline code.
        return [total_token / tp / 60 / 60 for tp in throughputs]

    pp_avg_time = _to_hours(pp)
    mlora_avg_time = _to_hours(mlora)

    fig, ax = plt.subplots(figsize=(5.5, 2.5), constrained_layout=True)
    x = [1, 2, 3, 4]
    ax.plot(x, pp_avg_time, color=c_1, marker="v", label="1F1B")
    ax.plot(x, mlora_avg_time, color=c_4, marker="*", label="mLoRA")

    ax.set_xticks(x)
    # Labels are derived from x so the two stay in sync.
    ax.set_xticklabels([str(n) for n in x])
    ax.set_yticks([0, 40, 60])
    ax.set_yticklabels(["0", "40", "60"])
    # Small headroom above the top tick so markers at 60 are not clipped.
    ax.set_ylim(0, 60 + 1)

    ax.set_title("(a) 70B A6000×4", fontsize=16)
    ax.set_ylabel("Average task\ncompletion time (h)")
    ax.set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # Annotation for the baselines that have no data points in this setting.
    ax.text(
        0.95,
        0.05,
        "FSDP : OOM\nTP : OOM",
        fontsize=12,
        va="bottom",
        ha="right",
        transform=ax.transAxes,
        color=c_4,
        style="italic",
    )
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 0.4),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )
    # Uncomment to export the figure for the paper.
    # plt.savefig("end_to_end_large_model.pdf", bbox_inches="tight", dpi=1000, format="pdf")

    # marimo renders a cell's last expression as its output; without this
    # line the output would be the Legend returned by fig.legend(...)
    # instead of the figure.
    fig
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        matplotlib,
        mlora,
        mlora_avg_time,
        np,
        plt,
        pp,
        pp_avg_time,
        total_token,
        x,
    )


if __name__ == "__main__":
    app.run()