import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell
def __():
    # Shared third-party imports, exported for the other cells.
    import matplotlib.pyplot as plt
    import numpy as np
    return np, plt


@app.cell
def __(plt):
    # Global matplotlib text style applied to every figure drawn afterwards.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 16
    return


@app.cell
def __(np, plt):
    # Draw a three-panel bar chart comparing the throughput of four training
    # strategies on three models, then save it to "pp_cmp_tp.pdf".
    #
    # Names starting with an underscore are cell-local temporaries in marimo,
    # so they are deliberately not part of the returned tuple.

    x = np.arange(4)  # not used by the plot; kept because the cell exports it

    # Throughput values in tokens/s, one entry per panel:
    # index 0 -> TinyLlama-1.1B, 1 -> Llama-2-7B, 2 -> Llama-2-13B.
    lorapp = np.array([10748, 2363.89, 1280.54])
    tp = np.array([5752.14, 1500, 875])
    fsdp = np.array([6151.91, 1750, 0])  # the 0 is annotated as "OOM" below
    gpipe = np.array([4599.87, 1284.27, 723.21])  # drawn under the "1F1B" label

    # tight_layout() below is the layout pass that actually takes effect, so
    # no layout engine is requested here (asking for "constrained" as well
    # only makes matplotlib warn that the layout was switched to tight).
    fig, ax = plt.subplots(figsize=(7, 2), ncols=3)

    # Bar geometry: each bar is three gaps wide, with one gap between bars.
    space_width = 3 / 17
    bar_width = 3 * space_width

    # RGB colours, one per strategy, in plotting order.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # x positions of the four bars, left to right.
    ticks = [
        space_width,
        2 * space_width + bar_width,
        3 * space_width + 2 * bar_width,
        4 * space_width + 3 * bar_width,
    ]
    # Legend labels in the same left-to-right order as `ticks`.
    ticks_label = ["1F1B", "TP", "FSDP", "LoRAPP"]

    # (x position, per-panel values, legend label, colour) for every strategy.
    _series = list(
        zip(ticks, (gpipe, tp, fsdp, lorapp), ticks_label, (c_1, c_2, c_3, c_4))
    )

    # (title, upper y limit, y tick positions, y tick labels) for every panel.
    _panels = [
        (
            "(a) TinyLlama-1.1B",
            12000,
            [1000, 4000, 7000, 10000],
            ["1k", "4k", "7k", "10k"],
        ),
        ("(b) Llama-2-7B", 2500, [1000, 2000], ["1000", "2000"]),
        ("(c) Llama-2-13B", 1500, [400, 800, 1200], ["400", "800", "1200"]),
    ]

    for _idx, (_title, _y_max, _y_ticks, _y_tick_labels) in enumerate(_panels):
        _axis = ax[_idx]
        # Every panel uses the same label/colour per strategy, so the legend
        # entries are consistent regardless of which panel provides them.
        for _pos, _values, _name, _color in _series:
            _axis.bar(_pos, _values[_idx], bar_width, label=_name, color=_color)
        # Hide the x-axis; the strategies are identified by the legend.
        _axis.get_xaxis().set_visible(False)
        _axis.set_ylim([0, _y_max])
        _axis.set_yticks(_y_ticks)
        _axis.set_yticklabels(
            _y_tick_labels, rotation=90, ha="center", va="center"
        )
        _axis.tick_params(bottom=False, labelsize=13, pad=7)
        _axis.set_title(_title, fontsize=16)

    # Only the leftmost panel carries the shared y-axis label.
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=12)

    # FSDP has no bar on the 13B panel (value 0); mark its slot instead.
    ax[2].text(
        ticks[2],
        100,
        "OOM",
        ha="center",
        va="center",
        color="r",
        style="italic",
        fontsize=14,
    )

    plt.tight_layout()
    # pyplot attaches the legend to the current axes (the last panel); the
    # negative x anchor shifts it left so it sits above the whole figure.
    plt.legend(
        ncol=4, loc="upper center", bbox_to_anchor=(-0.8, 1.50), frameon=False
    )
    plt.savefig("pp_cmp_tp.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        fsdp,
        gpipe,
        lorapp,
        space_width,
        ticks,
        ticks_label,
        tp,
        x,
    )


if __name__ == "__main__":
    app.run()