import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell
def __():
    import matplotlib.pyplot as plt
    import numpy as np
    import random
    return np, plt, random


@app.cell
def __(plt):
    # Global matplotlib styling applied to every figure this notebook draws.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 16
    return


@app.cell(hide_code=True)
def __(np, plt, random):
    # Build a 3x3 PEFT-vs-BatchLoRA comparison figure and save it to
    # "batchlora_op_task.pdf". Rows are the 1.1B, 7B and 13B models; columns
    # are (a) training time, (b) kernel launch time and (c) kernel execution
    # time, each plotted against the number of simultaneously trained LoRA
    # adapters.
    fig, ax = plt.subplots(
        figsize=(7, 4), ncols=3, nrows=3, layout="constrained"
    )

    # Line colours: c_2 draws PEFT and c_4 draws BatchLoRA.
    # c_1 and c_3 are not used in this cell.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Series indexed by (number of adapters - 1). *_tp_p_s feed the PEFT
    # curves and *_tp_m_s the BatchLoRA curves. 512 * 8 divided by an entry
    # is plotted as "Time (s)", so these are presumably throughputs.
    # TODO confirm the units and what 512 * 8 counts.
    b1_tp_m_s = np.array(
        [
            2857.40228,
            3016.124377,
            3043.99588,
            3047.335256,
            3051.551977,
            3051.512532,
            3048.015064,
            3047.108509,
            3048.642661,
            3051.840965,
            3047.57159,
            3047.861865,
        ]
    )
    b1_tp_p_s = np.array(
        [
            2857.389389,
            2842,
            2851,
            2847,
            2853,
            2841,
            2843,
            2851,
            2850,
            2851.3,
            2849,
            2852,
        ]
    )
    # pt and bt are not used by any panel in this cell.
    pt = [1.2955913555992142, 5.493869969292716, 9.928009143652595]
    bt = [1.2923869132290187, 5.493329344922422, 9.919060664229837]
    b7_tp_m_s = np.array(
        [692.5522131, 716.7349242, 723.2261427, 725.5761517, 727.0030057]
    )
    b7_tp_p_s = np.array(
        [692.6845352, 693.1034186, 691.3211972, 692.5283237, 690.1098232]
    )
    b13_tp_m_s = np.array([398.8387303, 403.5820717, 405.8601994])
    b13_tp_p_s = np.array([398.7009553, 398.4052117, 397.1230098])

    # Raw timing totals. None of these arrays is read by any panel below.
    b1_total_time = [
        12220066142,
        16686073253,
        28153512064,
        39069507033,
        51768122088,
        64214141018,
    ]
    b1_kern_launch_time = [
        5118037356,
        4186784734,
        3897274601,
        3017682983,
        3805590038,
        4490765774,
    ]
    b1_kern_exec_time = [
        7102028786,
        12499288519,
        24256237463,
        36051824050,
        47962532050,
        59723375244,
    ]
    b1_peft_total_time = [
        12325794153,
        23577089975,
        46729725461,
        72731733267,
        92082870159,
        1.17119e11,
    ]
    b1_peft_kern_launch_time = [
        5772647608,
        10437658459,
        20388356408,
        33118262295,
        39267405045,
        51109443930,
    ]
    b1_peft_kern_exec_time = [
        6553146545,
        13139431516,
        26341369053,
        39613470972,
        52815465114,
        66009300780,
    ]
    b7_total_time = [
        33120491718,
        57765632980,
        82496377307,
        1.09174e11,
        1.33464e11,
        1.62382e11,
    ]
    b7_kern_launch_time = [
        3662020415,
        2384382297,
        2776672210,
        2172410060,
        2099477024,
        2163734852,
    ]
    b7_kern_exec_time = [
        29458471303,
        55381250683,
        79719705097,
        1.07001e11,
        1.31365e11,
        1.60218e11,
    ]
    b7_peft_total_time = [
        33524811009,
        66969586849,
        1.00781e11,
        1.34885e11,
        1.67287e11,
        1.99614e11,
    ]
    b7_peft_kern_launch_time = [
        5231378776,
        9778801860,
        14391477684,
        19526454140,
        22768133008,
        26224092934,
    ]
    b7_peft_kern_exec_time = [
        28293432233,
        57190784989,
        86389988653,
        1.15359e11,
        1.44519e11,
        1.7339e11,
    ]
    b13_total_time = [58161406999, 1.02225e11, 1.4918e11, 1.9696e11]
    b13_kern_launch_time = [5415715572, 3854386236, 3606118159, 3500385192]
    b13_kern_exec_time = [52745691427, 98370557303, 1.45574e11, 1.93459e11]
    b13_peft_total_time = [58075574121, 1.1545e11, 1.73676e11, 2.30449e11]
    b13_peft_kern_launch_time = [
        7477607579,
        13302253495,
        19888792191,
        25492785326,
    ]
    b13_peft_kern_exec_time = [
        50597966542,
        1.02148e11,
        1.53788e11,
        2.04956e11,
    ]

    # NOTE(review): the *_k_time_* series below are NOT measurements.
    # - Entry 0 is a hard-coded constant.
    # - Every later entry is that constant plus uniform jitter from the
    #   unseeded random.random(). BatchLoRA is shifted 0.004-0.007 below and
    #   PEFT sits within +/-0.0005 of the constant.
    # - Columns (b) and (c) plot these values (column (b) plots training
    #   time minus them).
    # As a result, the saved figure presents synthesized numbers that change
    # on every run, while the measured *_kern_*_time arrays above go unused.
    # Replace these series with real measurements before publishing the
    # figure.
    base_b1 = 1.3109530583214795
    b1_k_time_lora = [1.3109530583214795]
    b1_k_time_peft = [1.3109530583214795]
    for i in range(1, 12):
        b1_k_time_lora.append(base_b1 - 0.004 - 0.001 * random.random())
        b1_k_time_peft.append(base_b1 - 0.001 * random.random() + 0.0005)

    base_b7 = 5.515842989778662
    b7_k_time_lora = [5.525842989778662]
    b7_k_time_peft = [5.515842989778662]
    for i in range(1, 5):
        b7_k_time_lora.append(base_b7 - 0.007 - 0.001 * random.random())
        b7_k_time_peft.append(base_b7 - 0.001 * random.random() + 0.0005)

    base_b13 = 9.989694870384067
    b13_k_time_lora = [9.989694870384067]
    b13_k_time_peft = [9.979694870384067]
    for i in range(1, 3):
        b13_k_time_lora.append(base_b13 - 0.007 - 0.001 * random.random())
        b13_k_time_peft.append(base_b13 - 0.001 * random.random() + 0.0005)

    def _plot_panel(
        axis,
        xs,
        peft_values,
        batchlora_values,
        *,
        ylim,
        yticks,
        yticklabels,
        xticks,
        xticklabels,
        model,
        tick_pad=7,
        title=None,
        ylabel=None,
        xlabel=None,
        legend_labels=False,
    ):
        """Draw one PEFT-vs-BatchLoRA panel on ``axis`` with shared styling.

        ``yticks`` and ``yticklabels`` must be the same length; each label is
        drawn at the tick in the same position. ``legend_labels`` attaches
        the "PEFT" and "BatchLoRA" labels. Only one panel should set it,
        because ``fig.legend`` collects every labelled line in the figure.
        """
        peft_kw = {"label": "PEFT"} if legend_labels else {}
        batchlora_kw = {"label": "BatchLoRA"} if legend_labels else {}
        axis.plot(xs, peft_values, color=c_2, marker="o", **peft_kw)
        axis.plot(
            xs, batchlora_values, color=c_4, marker="*", **batchlora_kw
        )
        axis.set_ylim(*ylim)
        axis.set_xticks(xticks)
        axis.set_xticklabels(xticklabels)
        # Fix the tick positions before labelling them. Calling
        # set_yticklabels while the default auto-locator is still active
        # only triggers matplotlib's FixedFormatter warning.
        axis.set_yticks(yticks)
        axis.set_yticklabels(
            yticklabels, rotation=90, ha="center", va="center", fontsize=12
        )
        axis.tick_params(pad=tick_pad)
        if ylabel is not None:
            axis.set_ylabel(ylabel)
        if xlabel is not None:
            axis.set_xlabel(xlabel)
        if title is not None:
            axis.set_title(title, fontsize=14, pad=9)
        axis.text(
            0.95,
            0.95,
            f"Model:{model}",
            fontsize=10,
            va="top",
            ha="right",
            transform=axis.transAxes,
        )

    # ---- Row 0: 1.1B model ----
    x = [1, 2, 4, 6, 8, 10, 12]
    ticks = [2, 4, 6, 8, 10, 12]
    ticks_label = ["2", "4", "6", "8", "10", "12"]

    y_peft = [512 * 8 / b1_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [512 * 8 / b1_tp_m_s[cnt - 1] for cnt in x]
    _plot_panel(
        ax[0][0],
        x,
        y_peft,
        y_batchlora,
        ylim=(1.3, 1.5),
        yticks=[1.3, 1.4, 1.5],
        yticklabels=["1.3", "1.4", "1.5"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="1.1B",
        ylabel="Time (s)",
        title="(a) Training time",
        legend_labels=True,
    )

    y_peft = [
        512 * 8 / b1_tp_p_s[cnt - 1] - b1_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        512 * 8 / b1_tp_m_s[cnt - 1] - b1_k_time_lora[cnt - 1] for cnt in x
    ]
    _plot_panel(
        ax[0][1],
        x,
        y_peft,
        y_batchlora,
        ylim=(0, 0.2),
        yticks=[0, 0.1, 0.2],
        yticklabels=["0", "0.1", "0.2"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="1.1B",
        title="(b) Kernel launch time",
    )

    y_peft = [b1_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b1_k_time_lora[cnt - 1] for cnt in x]
    _plot_panel(
        ax[0][2],
        x,
        y_peft,
        y_batchlora,
        ylim=(1.2, 1.4),
        yticks=[1.2, 1.3, 1.4],
        # Previously the 1.2 tick was mislabelled "1".
        yticklabels=["1.2", "1.3", "1.4"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="1.1B",
        tick_pad=8,
        title="(c) Kernel execution time",
    )

    # ---- Row 1: 7B model ----
    x = [1, 2, 3, 4, 5]
    ticks = [1, 2, 3, 4, 5]
    ticks_label = ["1", "2", "3", "4", "5"]

    y_peft = [512 * 8 / b7_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [512 * 8 / b7_tp_m_s[cnt - 1] for cnt in x]
    _plot_panel(
        ax[1][0],
        x,
        y_peft,
        y_batchlora,
        ylim=(5.6, 6.1),
        yticks=[5.6, 5.8, 6],
        yticklabels=["5.6", "5.8", "6"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="7B",
        ylabel="Time (s)",
    )

    y_peft = [
        512 * 8 / b7_tp_p_s[cnt - 1] - b7_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        512 * 8 / b7_tp_m_s[cnt - 1] - b7_k_time_lora[cnt - 1] for cnt in x
    ]
    _plot_panel(
        ax[1][1],
        x,
        y_peft,
        y_batchlora,
        ylim=(0, 0.6),
        yticks=[0, 0.3, 0.6],
        yticklabels=["0", "0.3", "0.6"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="7B",
    )

    y_peft = [b7_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b7_k_time_lora[cnt - 1] for cnt in x]
    _plot_panel(
        ax[1][2],
        x,
        y_peft,
        y_batchlora,
        ylim=(5.4, 5.6),
        yticks=[5.4, 5.5, 5.6],
        yticklabels=["5.4", "5.5", "5.6"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="7B",
    )

    # ---- Row 2: 13B model ----
    x = [1, 2, 3]
    ticks = [1, 2, 3]
    ticks_label = ["1", "2", "3"]

    y_peft = [512 * 8 / b13_tp_p_s[cnt - 1] for cnt in x]
    y_batchlora = [512 * 8 / b13_tp_m_s[cnt - 1] for cnt in x]
    _plot_panel(
        ax[2][0],
        x,
        y_peft,
        y_batchlora,
        ylim=(10, 10.6),
        yticks=[10, 10.3, 10.6],
        yticklabels=["10", "10.3", "10.6"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="13B",
        tick_pad=6,
        ylabel="Time (s)",
    )

    y_peft = [
        512 * 8 / b13_tp_p_s[cnt - 1] - b13_k_time_peft[cnt - 1] for cnt in x
    ]
    y_batchlora = [
        512 * 8 / b13_tp_m_s[cnt - 1] - b13_k_time_lora[cnt - 1] for cnt in x
    ]
    _plot_panel(
        ax[2][1],
        x,
        y_peft,
        y_batchlora,
        ylim=(0, 0.6),
        yticks=[0.0, 0.3, 0.6],
        yticklabels=["0", "0.3", "0.6"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="13B",
        xlabel="Number of simultaneously trained LoRA adapters",
    )

    y_peft = [b13_k_time_peft[cnt - 1] for cnt in x]
    y_batchlora = [b13_k_time_lora[cnt - 1] for cnt in x]
    _plot_panel(
        ax[2][2],
        x,
        y_peft,
        y_batchlora,
        ylim=(9.9, 10.1),
        yticks=[9.9, 10, 10.1],
        yticklabels=["9.9", "10", "10.1"],
        xticks=ticks,
        xticklabels=ticks_label,
        model="13B",
    )

    # Only ax[0][0] carries labels, so the figure legend has exactly two
    # entries: PEFT and BatchLoRA.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.75, 1.1),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    plt.savefig("batchlora_op_task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_k_time_lora,
        b13_k_time_peft,
        b13_kern_exec_time,
        b13_kern_launch_time,
        b13_peft_kern_exec_time,
        b13_peft_kern_launch_time,
        b13_peft_total_time,
        b13_total_time,
        b13_tp_m_s,
        b13_tp_p_s,
        b1_k_time_lora,
        b1_k_time_peft,
        b1_kern_exec_time,
        b1_kern_launch_time,
        b1_peft_kern_exec_time,
        b1_peft_kern_launch_time,
        b1_peft_total_time,
        b1_total_time,
        b1_tp_m_s,
        b1_tp_p_s,
        b7_k_time_lora,
        b7_k_time_peft,
        b7_kern_exec_time,
        b7_kern_launch_time,
        b7_peft_kern_exec_time,
        b7_peft_kern_launch_time,
        b7_peft_total_time,
        b7_total_time,
        b7_tp_m_s,
        b7_tp_p_s,
        base_b1,
        base_b13,
        base_b7,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        i,
        pt,
        ticks,
        ticks_label,
        x,
        y_batchlora,
        y_peft,
    )


if __name__ == "__main__":
    app.run()