import marimo

__generated_with = "0.7.12"
app = marimo.App(width="medium")


@app.cell
def __():
    import matplotlib.pyplot as plt
    import numpy as np
    return np, plt


@app.cell
def __(plt):
    # Global figure styling shared by every plot in this notebook.
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = 16
    return


@app.cell
def __(np, plt):
    # Plot the analytical communication cost (GB) of tensor parallelism (TP)
    # against pipeline parallelism (LoRAPP / 1F1B) as the number of GPUs d
    # grows, for three models, and save the figure to pp_cmp_com_cost.pdf.
    #
    # Cost models:
    #   TP: 8 * l * b * h * (d - 1) / d
    #   PP: 2 * b * h * (d - 1)
    # Here l is the model's layer count and h its hidden size. `token_size`
    # below plays the role of b.
    fig, ax = plt.subplots(figsize=(7, 2.6), ncols=3, layout="constrained")

    # Palette. c_1 and c_3 are unused in this figure but stay exported.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # The 1024**3 divisor converts bytes to GB.
    # NOTE(review): 512 * 8 * 4 presumably means tokens x batch x bytes per
    # element; confirm against the experiment setup.
    token_size = 512 * 8 * 4 / 1024 / 1024 / 1024

    # Number of GPUs (d) on the x axis.
    x = [2, 3, 4, 5, 6, 7, 8]

    # (layer count, hidden size) of each plotted model.
    l1 = 22
    h1 = 2048
    l2 = 32
    h2 = 4096
    l3 = 40
    h3 = 5120

    ticks = [2, 3, 4, 5, 6, 7, 8]
    ticks_label = ["2", "3", "4", "5", "6", "7", "8"]
    y_ticks = [0, 5, 10, 15, 20]
    y_ticks_label = ["0", "5", "10", "15", "20"]

    # Underscore-prefixed names are cell-local in marimo. This keeps the loop
    # helpers out of the cell's exported namespace.
    _panels = [
        (l1, h1, "(a) TinyLlama-1.1B"),
        (l2, h2, "(b) Llama-2-7B"),
        (l3, h3, "(c) Llama-2-13B"),
    ]
    for _ax, (_layers, _hidden, _title) in zip(ax, _panels):
        # y_tp / y_pp are rebound for each panel. After the loop they hold the
        # last panel's series, which is what this cell exports.
        y_tp = [8 * _layers * token_size * _hidden * (d - 1) / d for d in x]
        y_pp = [2 * token_size * _hidden * (d - 1) for d in x]
        (_tp_line,) = _ax.plot(x, y_tp, color=c_2, marker="o")
        (_pp_line,) = _ax.plot(x, y_pp, color=c_4, marker="*")

        # Identical y range on every panel so the models compare visually.
        _ax.set_ylim([-1, 25])
        _ax.set_xticks(ticks)
        _ax.set_xticklabels(ticks_label)
        _ax.set_yticks(y_ticks)
        _ax.set_yticklabels(y_ticks_label, rotation=90, ha="center", va="center")
        _ax.set_title(_title, fontsize=16)
        _ax.set_xlabel("Number of GPUs", fontsize=16)
        _ax.tick_params(labelsize=16, pad=7)

    # Only the leftmost panel carries the y-axis label.
    ax[0].set_ylabel("Communication Cost (GB)", fontsize=16)

    # One shared legend above the panels. Every panel uses the same styling,
    # so the last panel's line handles stand for all of them.
    fig.legend(
        [_tp_line, _pp_line],
        ["TP", "LoRAPP and 1F1B"],
        ncol=2,
        bbox_to_anchor=(0.8, 1.17),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    fig.savefig("pp_cmp_com_cost.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        h1,
        h2,
        h3,
        l1,
        l2,
        l3,
        ticks,
        ticks_label,
        token_size,
        x,
        y_pp,
        y_ticks,
        y_ticks_label,
        y_tp,
    )


if __name__ == "__main__":
    app.run()