import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell
def __():
    # Setup cell: plotting imports, global font settings and the colour
    # palette handed to the other cells.
    import matplotlib.pyplot as plt
    import numpy as np

    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 16})

    def _from_rgb255(red, green, blue):
        # Scale 0-255 channel values down to 0-1 floats.
        return (red / 255, green / 255, blue / 255)

    c_1 = _from_rgb255(139, 0, 0)
    c_2 = (0, 0, 0)
    c_3 = _from_rgb255(191, 191, 191)
    c_4 = _from_rgb255(230, 109, 104)
    return c_1, c_2, c_3, c_4, np, plt


@app.cell
def __(c_2, c_4, plt):
    # Two-panel grouped bar figure saved as "overlapping.pdf".
    # ax[0] is titled "(a)" (throughput), ax[1] is titled "(b)" (time share).
    fig, ax = plt.subplots(1, 2, figsize=(7, 3.5), constrained_layout=True)

    # Horizontal layout shared by both panels: a gap of `space_width`
    # precedes every pair of bars, each bar being three gaps wide.
    space_width = 3 / 22
    bar_width = 3 * space_width
    # NOTE(review): with matplotlib's default centre-aligned bars these ticks
    # appear to sit 0.5 * space_width right of each pair's midpoint -- confirm
    # this offset is intended.
    x_ticks = [
        bar_width,
        space_width + 3 * bar_width,
        2 * space_width + 5 * bar_width,
    ]
    x_ticks_label = ["1.1B", "7B", "13B"]

    def _draw_panel(
        axis, heights, legend_labels, y_label, y_ticks, y_tick_text, title
    ):
        # `heights` lists the six bars left to right; even positions use c_2,
        # odd positions use c_4. `legend_labels` maps a bar's position to its
        # legend entry; bars that are not listed get no `label` keyword at all.
        pair_colors = (c_2, c_4)
        for position, height in enumerate(heights):
            group, member = divmod(position, 2)
            x_center = (
                (group + 1) * space_width + (2 * group + member) * bar_width
            )
            extra = {}
            if position in legend_labels:
                extra["label"] = legend_labels[position]
            axis.bar(
                x_center,
                height,
                bar_width,
                color=pair_colors[member],
                **extra,
            )

        axis.set_xticks(x_ticks)
        axis.set_xticklabels(x_ticks_label)
        # Hide the bottom tick marks but keep the (smaller) tick labels.
        axis.tick_params(bottom=False, labelsize=14, pad=7)
        axis.set_ylabel(y_label, fontsize=16)
        axis.set_yticks(y_ticks)
        axis.set_yticklabels(
            y_tick_text, rotation=90, ha="center", va="center"
        )
        axis.set_title(title, fontsize=16)
        axis.set_xlabel(
            "Base model with \ndifferent parameter scales", fontsize=16
        )

    # Panel (b) is drawn first, matching the original drawing order.
    _draw_panel(
        ax[1],
        heights=[45, 55, 75, 25, 85, 15],
        legend_labels={1: "Communication", 2: "Computation"},
        y_label="The proportion of \nthe total time (%)",
        y_ticks=[0, 50, 100],
        y_tick_text=["0", "50", "100"],
        title="(b)",
    )

    #####
    #####
    #####

    _draw_panel(
        ax[0],
        heights=[7062.16, 11625, 1874.48, 2270, 1102, 1280],
        legend_labels={1: "With overlapping", 2: "Without overlapping"},
        y_label="Throughput (tokens/s)",
        y_ticks=[0, 2500, 5000, 7500, 10000],
        y_tick_text=["0", "2.5k", "5k", "7.5k", "10k"],
        title="(a)",
    )

    # Frameless, transparent legends; only the anchor differs per panel.
    _legend_style = dict(
        ncol=1, fancybox=False, framealpha=0.0, fontsize=10
    )
    ax[1].legend(bbox_to_anchor=(0.55, 1), **_legend_style)
    ax[0].legend(bbox_to_anchor=(0.35, 1), **_legend_style)

    plt.savefig("overlapping.pdf", bbox_inches="tight", dpi=1000)
    return ax, bar_width, fig, space_width, x_ticks, x_ticks_label


if __name__ == "__main__":
    app.run()