import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell
def __():
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    return matplotlib, np, plt


@app.cell
def __(matplotlib, plt):
    # Global figure style: no LaTeX text rendering, Times New Roman, 16 pt.
    matplotlib.rcParams["text.usetex"] = False
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = 16
    return


@app.cell
def __(np, plt):
    # Draws a 3x4 grid of "task completion time vs. number of trained LoRA
    # adapters" panels and saves it as end-to-end-total.pdf.
    #   row 0: A6000 x4, row 1: 3090 x8, row 2: single A6000
    #   columns: 1.1B, 7B, 13B, 70B models
    #
    # Names starting with an underscore are cell-private in marimo, so the
    # helpers below do not change the set of names this cell exports.
    # NOTE: `np` is kept in the signature for interface stability; the only
    # use of it in this cell was a dead store that has been removed.

    def _completion_hours(throughputs, tokens_per_task):
        # The i-th entry (1-based) is the throughput measured while training
        # i adapters at once, so finishing i tasks takes
        # tokens_per_task * i / throughput seconds; / 60 / 60 converts to hours.
        # Throughput is presumably in tokens per second -- TODO confirm.
        return [
            tokens_per_task * count / tput / 60 / 60
            for count, tput in enumerate(throughputs, start=1)
        ]

    def _note(axis, pos_x, pos_y, message, va="bottom", ha="right"):
        # Put a small status note (e.g. which baselines ran out of memory)
        # inside a panel, positioned in axes-fraction coordinates.
        axis.text(
            pos_x,
            pos_y,
            message,
            fontsize=12,
            va=va,
            ha=ha,
            transform=axis.transAxes,
            color=c_4,
        )

    fig, ax = plt.subplots(
        figsize=(14, 14 / 2.2), ncols=4, nrows=3, layout="constrained"
    )

    # Line colours: dark red, black, light grey, salmon.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Tokens processed per task; presumably samples * sequence length *
    # epochs -- TODO confirm.
    total_token = 7473 * 512 * 10

    # ---- Throughput measurements: single A6000 (mLoRA vs. PEFT) ----
    b1_tp_m_s = [
        2857.40228,
        3016.124377,
        3043.99588,
        3047.335256,
        3051.551977,
        3051.512532,
        3048.015064,
        3047.108509,
        3048.642661,
        3051.840965,
        3047.57159,
        3047.861865,
    ]
    b1_tp_p_s = [
        2857.389389,
        2842,
        2851,
        2847,
        2853,
        2841,
        2843,
        2851,
        2850,
        2851.3,
        2849,
        2852,
    ]
    b7_tp_m_s = [702.5522131, 716.7349242, 722.2261427, 725.5761517, 727.0030057]
    b7_tp_p_s = [702.6845352, 699.1034186, 701.3211972, 700.1283237, 700.1098232]
    b13_tp_m_s = [398.8387303, 403.5820717, 405.8601994]
    b13_tp_p_s = [398.7009553, 398.4052117, 399.1230098]

    # ---- Throughput measurements: A6000 x4 ----
    b1_throughput_mLoRA = [
        4600.45,
        8664.91,
        10118.36,
        10184.44,
        10119,
        11157,
        11157,
        11530,
        11580,
        11580,
        11600,
        11600,
    ]
    b1_throughput_tp = [
        5752.14,
        5749.34,
        5756.78,
        5758.32,
        5753.14,
        5753.34,
        5756.78,
        5756.32,
        5758.14,
        5749.34,
        5753.78,
        5754.32,
    ]
    b1_throughput_fsdp = [
        6151.91,
        6141.73,
        6161.23,
        6153.93,
        6153.91,
        6146.73,
        6161.23,
        6151.93,
        6157.91,
        6143.73,
        6161.23,
        6153.93,
    ]
    b1_throughput_gpipe = [
        4599.87,
        4610.19,
        4592.17,
        4601.18,
        4598.87,
        4600.19,
        4593.17,
        4601.18,
        4599.87,
        4610.19,
        4592.17,
        4603.18,
    ]
    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]
    b13_throughput_mLoRA = [723.21, 1280.54]
    b13_throughput_tp = [875, 877]
    # Placeholder for the failed run ("for the error"). Never converted to
    # hours: the zeros would divide by zero.
    b13_throughput_fsdp = [0, 0, 0, 0]
    b13_throughput_gpipe = [723.21, 719.21]
    b70_throughput_mLoRA = [234, 291, 320, 318]
    b70_throughput_gpipe = [234.34, 234.32, 234.38, 234]

    # ---- Throughput measurements: 3090 x8 ----
    # NOTE(review): these variables say "4090" while the panel titles say
    # "3090x8" -- confirm which GPU was actually used. The names are part of
    # this cell's returned interface, so they are left unchanged here.
    b1_4090_mlora = [
        319.61,
        580.34,
        663.12,
        799.69,
        800.96,
        813.64,
        812.92,
        814.93,
    ]
    b1_4090_tp = [35.17, 35.33, 35.32, 35.32, 35.38, 35.34, 35.33, 35.34]
    b1_4090_fsdp = [62.79, 62.31, 60.03, 61.53, 62.79, 60.37, 61.23, 63.11]
    b1_4090_gpipe = [
        318.99,
        319.16,
        319.61,
        319.42,
        319.23,
        319.61,
        319.80,
        319.96,
    ]
    b7_4090_mlora = [576.79, 671.39, 702.91, 715.85]
    b7_4090_gpipe = [578.97, 581.46, 573.70, 580.84]
    b7_4090_fsdp = []
    b7_4090_tp = [11.93, 12.09, 11.86, 12.10]
    b13_4090_mlora = [524.57, 694.34]
    b13_4090_gpipe = [528.47, 522.06]
    b13_4090_fsdp = []
    b13_4090_tp = []

    # ================= Row 0: A6000 x4 =================
    # (a) 1.1B
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]
    m_avg_time = _completion_hours(b1_throughput_mLoRA, total_token)
    tp_avg_time = _completion_hours(b1_throughput_tp, total_token)
    fsdp_avg_time = _completion_hours(b1_throughput_fsdp, total_token)
    g_avg_time = _completion_hours(b1_throughput_gpipe, total_token)
    ax[0][0].plot(x, g_avg_time, color=c_1, marker="v")
    ax[0][0].plot(x, fsdp_avg_time, color=c_3, marker="^")
    ax[0][0].plot(x, tp_avg_time, color=c_2, marker="o")
    ax[0][0].plot(x, m_avg_time, color=c_4, marker="*")
    ax[0][0].set_xticks(ticks)
    ax[0][0].set_xticklabels(ticks_label)

    # (b) 7B -- this panel carries the legend labels for the three baselines.
    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _completion_hours(b7_throughput_mLoRA, total_token)
    tp_avg_time = _completion_hours(b7_throughput_tp, total_token)
    fsdp_avg_time = _completion_hours(b7_throughput_fsdp, total_token)
    g_avg_time = _completion_hours(b7_throughput_gpipe, total_token)
    ax[0][1].plot(x, g_avg_time, color=c_1, marker="v", label="1F1B")
    ax[0][1].plot(x, fsdp_avg_time, color=c_3, marker="^", label="FSDP")
    ax[0][1].plot(x, tp_avg_time, color=c_2, marker="o", label="TP")
    ax[0][1].plot(x, m_avg_time, color=c_4, marker="*")
    ax[0][1].set_xticks(ticks)
    ax[0][1].set_xticklabels(ticks_label)

    # (c) 13B -- no FSDP curve; the panel is annotated "FSDP : OOM" below.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    m_avg_time = _completion_hours(b13_throughput_mLoRA, total_token)
    tp_avg_time = _completion_hours(b13_throughput_tp, total_token)
    g_avg_time = _completion_hours(b13_throughput_gpipe, total_token)
    ax[0][2].plot(x, g_avg_time, color=c_1, marker="v")
    ax[0][2].plot(x, tp_avg_time, color=c_2, marker="o")
    ax[0][2].plot(x, m_avg_time, color=c_4, marker="*")
    ax[0][2].set_xticks(ticks)
    ax[0][2].set_xticklabels(ticks_label)

    # (d) 70B -- only 1F1B and mLoRA have measurements.
    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _completion_hours(b70_throughput_mLoRA, total_token)
    g_avg_time = _completion_hours(b70_throughput_gpipe, total_token)
    ax[0][3].plot(x, g_avg_time, color=c_1, marker="v")
    ax[0][3].plot(x, m_avg_time, color=c_4, marker="*")
    ax[0][3].set_xticks(ticks)
    ax[0][3].set_xticklabels(ticks_label)

    # ================= Row 1: 3090 x8 =================
    # (e) 1.1B
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    ticks = [2, 4, 6, 8]
    ticks_label = ["2", "4", "6", "8"]
    m_avg_time = _completion_hours(b1_4090_mlora, total_token)
    tp_avg_time = _completion_hours(b1_4090_tp, total_token)
    fsdp_avg_time = _completion_hours(b1_4090_fsdp, total_token)
    g_avg_time = _completion_hours(b1_4090_gpipe, total_token)
    ax[1][0].plot(x, g_avg_time, color=c_1, marker="v")
    ax[1][0].plot(x, fsdp_avg_time, color=c_3, marker="^")
    ax[1][0].plot(x, tp_avg_time, color=c_2, marker="o")
    ax[1][0].plot(x, m_avg_time, color=c_4, marker="*")
    ax[1][0].set_xticks(ticks)
    ax[1][0].set_xticklabels(ticks_label)

    # (f) 7B -- FSDP and TP are deliberately not drawn; their status is given
    # by the text note added below. The series are still computed because
    # the cell returns `tp_avg_time` / `fsdp_avg_time`.
    x = [1, 2, 3, 4]
    ticks = [1, 2, 3, 4]
    ticks_label = ["1", "2", "3", "4"]
    m_avg_time = _completion_hours(b7_4090_mlora, total_token)
    tp_avg_time = _completion_hours(b7_4090_tp, total_token)
    fsdp_avg_time = _completion_hours(b7_4090_fsdp, total_token)
    g_avg_time = _completion_hours(b7_4090_gpipe, total_token)
    ax[1][1].plot(x, g_avg_time, color=c_1, marker="v")
    ax[1][1].plot(x, m_avg_time, color=c_4, marker="*")
    ax[1][1].set_xticks(ticks)
    ax[1][1].set_xticklabels(ticks_label)

    # (g) 13B -- FSDP and TP have no data (empty lists) and are not drawn.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    m_avg_time = _completion_hours(b13_4090_mlora, total_token)
    tp_avg_time = _completion_hours(b13_4090_tp, total_token)
    fsdp_avg_time = _completion_hours(b13_4090_fsdp, total_token)
    g_avg_time = _completion_hours(b13_4090_gpipe, total_token)
    ax[1][2].plot(x, g_avg_time, color=c_1, marker="v")
    ax[1][2].plot(x, m_avg_time, color=c_4, marker="*")
    ax[1][2].set_xticks(ticks)
    ax[1][2].set_xticklabels(ticks_label)

    # (h) 70B -- nothing to plot; only ticks and a text note.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    ax[1][3].set_xticks(ticks)
    ax[1][3].set_xticklabels(ticks_label)

    # ================= Row 2: single A6000 =================
    # (i) 1.1B -- this panel carries the PEFT and mLoRA legend labels.
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]
    m_avg_time = _completion_hours(b1_tp_m_s, total_token)
    p_avg_time = _completion_hours(b1_tp_p_s, total_token)
    ax[2][0].plot(x, p_avg_time, color=c_2, marker="^", label="PEFT")
    ax[2][0].plot(x, m_avg_time, color=c_4, marker="*", label="mLoRA")
    ax[2][0].set_xticks(ticks)
    ax[2][0].set_xticklabels(ticks_label)

    # (j) 7B
    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]
    m_avg_time = _completion_hours(b7_tp_m_s, total_token)
    p_avg_time = _completion_hours(b7_tp_p_s, total_token)
    ax[2][1].plot(x, p_avg_time, color=c_2, marker="^")
    ax[2][1].plot(x, m_avg_time, color=c_4, marker="*")
    ax[2][1].set_xticks(ticks)
    ax[2][1].set_xticklabels(ticks_label)

    # (k) 13B
    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]
    m_avg_time = _completion_hours(b13_tp_m_s, total_token)
    p_avg_time = _completion_hours(b13_tp_p_s, total_token)
    ax[2][2].plot(x, p_avg_time, color=c_2, marker="^")
    ax[2][2].plot(x, m_avg_time, color=c_4, marker="*")
    ax[2][2].set_xticks(ticks)
    ax[2][2].set_xticklabels(ticks_label)

    # (l) 70B -- nothing to plot; only ticks and a text note.
    x = [1, 2]
    ticks = [1, 2]
    ticks_label = ["1", "2"]
    ax[2][3].set_xticks(ticks)
    ax[2][3].set_xticklabels(ticks_label)

    # ---- Axis ranges (hand-tuned per panel) ----
    ax[0][0].set_ylim(0, 30)
    ax[0][1].set_ylim(0, 40)
    ax[0][2].set_ylim(0, 40)
    ax[0][3].set_ylim(0, 200)

    ax[1][0].set_ylim(0, 2500)
    ax[1][1].set_ylim(0, 80)
    ax[1][2].set_ylim(0, 50)
    ax[1][3].set_ylim(0, 1000)

    ax[2][0].set_ylim(0, 50)
    ax[2][1].set_ylim(0, 100)
    ax[2][2].set_ylim(0, 100)
    ax[2][3].set_ylim(0, 100)

    # Add horizontal padding so the end-point markers are not clipped.
    ax[0][2].set_xlim(0.8, 3 - 0.8)
    ax[1][0].set_xlim(0.5, 9 - 0.5)
    ax[1][1].set_xlim(0.5, 5 - 0.5)
    ax[1][2].set_xlim(0.8, 3 - 0.8)
    ax[2][2].set_xlim(0.8, 4 - 0.8)

    # ---- Panel titles ----
    ax[0][0].set_title("(a) 1.1B A6000×4", fontsize=16)
    ax[0][1].set_title("(b) 7B A6000×4", fontsize=16)
    ax[0][2].set_title("(c) 13B A6000×4", fontsize=16)
    ax[0][3].set_title("(d) 70B A6000×4", fontsize=16)

    ax[1][0].set_title("(e) 1.1B 3090×8", fontsize=16)
    ax[1][1].set_title("(f) 7B 3090×8", fontsize=16)
    ax[1][2].set_title("(g) 13B 3090×8", fontsize=16)
    ax[1][3].set_title("(h) 70B 3090×8", fontsize=16)

    ax[2][0].set_title("(i) 1.1B A6000", fontsize=16)
    ax[2][1].set_title("(j) 7B A6000", fontsize=16)
    ax[2][2].set_title("(k) 13B A6000", fontsize=16)
    ax[2][3].set_title("(l) 70B A6000", fontsize=16)

    # Only the first column carries a y-axis label.
    ax[0][0].set_ylabel("Task\ncompletion time (h)")
    ax[1][0].set_ylabel("Task\ncompletion time (h)")
    ax[2][0].set_ylabel("Task\ncompletion time (h)")

    # ---- Notes for configurations that have no curve ----
    _note(ax[0][2], 0.9, 0.8, "FSDP : OOM")
    _note(ax[0][3], 0.9, 0.1, "FSDP : OOM\nTP : OOM")
    _note(ax[1][1], 0.9, 0.1, "FSDP : OOM\nTP : about one month")
    _note(ax[1][2], 0.9, 0.1, "FSDP : OOM\nTP : OOM")
    _note(
        ax[1][3],
        0.5,
        0.5,
        "FSDP : OOM\nTP : OOM\n1F1B : OOM\nmLoRA : OOM",
        va="center",
        ha="center",
    )
    _note(
        ax[2][3],
        0.5,
        0.5,
        "PEFT : OOM\nmLoRA : OOM",
        va="center",
        ha="center",
    )

    # One shared legend built from the labelled lines in panels (b) and (i).
    fig.legend(
        ncol=5,
        bbox_to_anchor=(0.75, 1.05),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )

    fig.supxlabel(
        "Number of trained LoRA adapters",
        fontsize=16,
        y=-0.03,
        ha="center",
        va="bottom",
    )

    plt.savefig("end-to-end-total.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_4090_fsdp,
        b13_4090_gpipe,
        b13_4090_mlora,
        b13_4090_tp,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_4090_fsdp,
        b1_4090_gpipe,
        b1_4090_mlora,
        b1_4090_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b1_tp_m_s,
        b1_tp_p_s,
        b70_throughput_gpipe,
        b70_throughput_mLoRA,
        b7_4090_fsdp,
        b7_4090_gpipe,
        b7_4090_mlora,
        b7_4090_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        b7_tp_m_s,
        b7_tp_p_s,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp_avg_time,
        g_avg_time,
        m_avg_time,
        p_avg_time,
        ticks,
        ticks_label,
        total_token,
        tp_avg_time,
        x,
    )


if __name__ == "__main__":
    app.run()