import marimo

__generated_with = "0.7.0"
app = marimo.App(width="medium")


@app.cell
def __():
    import matplotlib.pyplot as plt
    import numpy as np
    import random
    return np, plt, random


@app.cell
def __(plt):
    # Global font settings applied to every figure drawn afterwards.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 16
    return


@app.cell(hide_code=True)
def __(plt):
    # Three-panel bar chart comparing PEFT (colour c_2) with BatchLoRA
    # (colour c_4) for the 1.1B / 7B / 13B models:
    #   (a) training time, (b) kernel-launch share, (c) kernel execution time.
    fig, ax = plt.subplots(figsize=(7, 2.4), ncols=3, layout="constrained")

    # Layout: each model gets a gap of `space_width` followed by two
    # adjacent bars, each `bar_width` wide.
    space_width = 3 / 22
    bar_width = 3 * space_width

    c_1 = (139 / 255, 0 / 255, 0 / 255)  # not used in this cell; exported
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # not used in this cell; exported
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Bar centres (Axes.bar defaults to align="center"): pair i puts its
    # PEFT bar at (i + 1) * space_width + 2 * i * bar_width and its
    # BatchLoRA bar one bar_width further right.
    _left = [(_i + 1) * space_width + 2 * _i * bar_width for _i in range(3)]
    _right = [_x + bar_width for _x in _left]

    # Each tick sits at the midpoint of its pair so the label is centred
    # under both bars.
    x_ticks = [_x + bar_width / 2 for _x in _left]
    x_ticks_label = ["1.1B", "7B", "13B"]

    def _draw_pairs(axis, peft_vals, batchlora_vals, with_labels=False):
        """Draw one PEFT / BatchLoRA bar pair per model on `axis`.

        When `with_labels` is true, only the first pair carries legend
        labels, so `fig.legend` shows exactly one entry per method.
        """
        for i, (x_l, x_r, p, b) in enumerate(
            zip(_left, _right, peft_vals, batchlora_vals)
        ):
            # "" is Axes.bar's default label, i.e. no legend entry.
            if with_labels and i == 0:
                peft_label, batch_label = "PEFT", "BatchLoRA"
            else:
                peft_label, batch_label = "", ""
            axis.bar(x_l, p, bar_width, color=c_2, label=peft_label)
            axis.bar(x_r, b, bar_width, color=c_4, label=batch_label)

    def _style_axis(axis, ylabel, title, title_pad=None):
        """Apply the shared tick, y-label and title formatting to a panel."""
        axis.set_xticks(x_ticks)
        axis.set_xticklabels(x_ticks_label, ha="center", va="center")
        axis.tick_params(bottom=False, labelsize=14, pad=7)
        axis.set_ylabel(ylabel, fontsize=16)
        # title_pad=None falls back to rcParams["axes.titlepad"].
        axis.set_title(title, fontsize=16, pad=title_pad)

    # (a) Training time: 512 * 8 divided by each measured value
    # (the units of the measured values are not recorded here).
    _work = 512 * 8
    _draw_pairs(
        ax[0],
        [_work / 2812, _work / 690.5880666, _work / 396.4798927],
        [_work / 3054, _work / 727.7364507, _work / 405.9223082],
    )
    _style_axis(ax[0], "Time (s)", "(a) Training time", title_pad=22)

    # (b) Share of time spent launching kernels, in percent.
    _draw_pairs(ax[1], [10.1, 7.5, 3.9], [2.4, 2.1, 1.1])
    # Two-line title, drawn with the default pad.
    _style_axis(
        ax[1], "Percentage (%)", "(b) Proportion of \nkernel launch time"
    )

    # (c) Kernel execution time. The *_kern_time lists are not plotted in
    # this cell, but they are still exported.
    peft_kern_time = [12969965266, 56659466215, 96280798449]
    batchlora_kern_time = [12969965266, 55606549735, 96080798449]
    pt = [1.3109530583214795, 5.493869969292716, 9.928009143652595]
    bt = [1.309003274394237, 5.493329344922422, 9.919060664229837]
    _draw_pairs(ax[2], pt, bt, with_labels=True)
    _style_axis(ax[2], "Time (s)", "(c) Kernel execution time", title_pad=22)

    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.8, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    # plt.savefig("batchlora_cmp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        batchlora_kern_time,
        bt,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        peft_kern_time,
        pt,
        space_width,
        x_ticks,
        x_ticks_label,
    )


if __name__ == "__main__":
    app.run()