import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell
def __():
    import matplotlib.pyplot as plt
    import numpy as np
    return np, plt


@app.cell
def __(plt):
    # Global matplotlib font settings used when the comparison figure is drawn.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 16
    return


@app.cell(hide_code=True)
def __(np, plt):
    # Line colours: black for the operator without graph pruning, a muted red
    # for the operator with graph pruning.
    c_2 = (0, 0, 0)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    def _plot_panel(
        axis,
        y_torch,
        y_batchlora,
        *,
        model,
        ylim,
        yticks,
        ylabel=None,
        legend=False,
    ):
        """Draw one without-/with-graph-pruning comparison onto ``axis``.

        The x axis is the number of simultaneously trained LoRA adapters and
        runs 1..len(y_torch), with one integer tick per point.

        Args:
            axis: Matplotlib Axes to draw on.
            y_torch: Measurements of the operator without graph pruning.
            y_batchlora: Measurements of the operator with graph pruning;
                must have the same length as ``y_torch``.
            model: Model-size tag drawn in the lower-right corner.
            ylim: ``(bottom, top)`` y-axis limits.
            yticks: Y tick positions; each is labelled with ``str(tick)``.
            ylabel: Optional y-axis label (only the left column uses one).
            legend: Attach legend labels to both lines. Enable this for exactly
                one panel so ``fig.legend`` lists each series only once.

        Returns:
            numpy array of ``(y_torch - y_batchlora) / y_torch * 100`` per
            point, i.e. the relative reduction (%) achieved by graph pruning.

        Raises:
            ValueError: if the two series have different lengths.
        """
        if len(y_torch) != len(y_batchlora):
            raise ValueError(
                f"series length mismatch: {len(y_torch)} vs {len(y_batchlora)}"
            )
        ticks = list(range(1, len(y_torch) + 1))

        # Only the legend panel gets labels. Unlabelled lines are skipped by
        # fig.legend, so the figure legend lists each series once.
        torch_kw = {"label": "Operator without Graph Pruning"} if legend else {}
        pruned_kw = {"label": "Operator with Graph Pruning"} if legend else {}
        axis.plot(ticks, y_torch, color=c_2, marker="o", **torch_kw)
        axis.plot(ticks, y_batchlora, color=c_4, marker="*", **pruned_kw)

        axis.set_ylim(*ylim)
        axis.set_xticks(ticks)
        axis.set_xticklabels([str(t) for t in ticks])
        axis.set_yticks(yticks)
        axis.set_yticklabels(
            [str(t) for t in yticks],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        if ylabel is not None:
            axis.set_ylabel(ylabel)
        axis.tick_params(pad=7)
        axis.text(
            0.95,
            0.02,
            f"Model:{model}",
            fontsize=14,
            va="bottom",
            ha="right",
            transform=axis.transAxes,
        )

        baseline = np.asarray(y_torch, dtype=float)
        return (baseline - np.asarray(y_batchlora, dtype=float)) / baseline * 100

    # Row 0: average forward+backward time; row 1: peak memory. The row's
    # metric name doubles as the y-label of its left-most subplot.
    _row_metric = {0: "Time (us)", 1: "Peak Memory (GB)"}

    # One entry per subplot: (row, col) position and the measured series.
    _panels = [
        {
            "pos": (0, 0),
            "model": "1.1B",
            "y_torch": [
                20.29159868, 39.16359761, 55.65218289, 73.73416966,
                88.12034672, 104.6562161, 124.9680397, 138.4435594,
            ],
            "y_batchlora": [
                24.05060625, 36.4906544, 49.58459261, 62.96516142,
                75.19585842, 86.59812198, 104.7908515, 116.5208559,
            ],
            "ylim": (0, 170),
            "yticks": [0, 50, 100, 150],
        },
        {
            "pos": (0, 1),
            "model": "7B",
            "y_torch": [21.9846773, 40.29510077, 59.41132316, 77.31547579, 95.81772611],
            "y_batchlora": [23.42831809, 37.3211666, 51.31030688, 65.47136931, 79.53268476],
            "ylim": (0, 170),
            "yticks": [0, 50, 100, 150],
        },
        {
            "pos": (0, 2),
            "model": "13B",
            "y_torch": [22.37562835, 42.73959994, 61.96534634],
            "y_batchlora": [22.80589938, 39.08431157, 54.03284915],
            "ylim": (0, 85),
            "yticks": [0, 25, 50, 75],
        },
        {
            "pos": (1, 0),
            "model": "1.1B",
            "y_torch": [
                3.501786362, 5.201629498, 6.926035673, 8.653747242,
                10.38349666, 12.11389446, 13.84254159, 15.57186355,
            ],
            "y_batchlora": [
                3.393746125, 4.998805098, 6.571097296, 8.141429216,
                9.714332745, 11.28814712, 12.85960523, 14.43241727,
            ],
            "ylim": (0, 18),
            "yticks": [0, 5, 10, 15],
        },
        {
            "pos": (1, 1),
            "model": "7B",
            "y_torch": [11.09702569, 14.20903317, 17.3902113, 20.59260555, 23.79299152],
            "y_batchlora": [10.8060544, 13.69216517, 16.51687164, 19.33916381, 22.15532633],
            "ylim": (10, 27),
            "yticks": [10, 15, 20, 25],
        },
        {
            "pos": (1, 2),
            "model": "13B",
            "y_torch": [18.52022991, 22.66454649, 26.87934514],
            "y_batchlora": [18.16663067, 22.02296134, 25.78340335],
            "ylim": (15, 32),
            "yticks": [15, 20, 25, 30],
        },
    ]

    fig, ax = plt.subplots(
        figsize=(7, 3.5), ncols=3, nrows=2, layout="constrained"
    )

    for _p in _panels:
        _row, _col = _p["pos"]
        _reduction = _plot_panel(
            ax[_row][_col],
            _p["y_torch"],
            _p["y_batchlora"],
            model=_p["model"],
            ylim=_p["ylim"],
            yticks=_p["yticks"],
            ylabel=_row_metric[_row] if _col == 0 else None,
            legend=(_row, _col) == (0, 0),
        )
        # Relative reduction from graph pruning; negative means slower or larger.
        print(
            f"{_row_metric[_row]}, Model:{_p['model']}, reduction (%):",
            _reduction.tolist(),
        )

    ax[1][1].set_xlabel("Number of simultaneously trained LoRA adapters")
    ax[0][1].set_title(
        "(a) The average time of forward and backward in the LoRA Operator",
        fontsize=16,
    )
    ax[1][1].set_title("(b) The peak memory used of LoRA Operator", fontsize=16)

    fig.legend(
        ncol=3,
        bbox_to_anchor=(1, 1.12),
        fancybox=False,
        framealpha=0.0,
        fontsize=14,
    )
    # Save this cell's figure explicitly rather than pyplot's "current" figure.
    fig.savefig("batchlora_op_cmp.pdf", bbox_inches="tight", dpi=1000)
    return ax, c_2, c_4, fig


if __name__ == "__main__":
    app.run()