# Source listing metadata: 127 lines, 3.3 KiB, Python
import marimo
|
|
|
|
# Version string of the marimo release that generated this notebook file.
__generated_with = "0.7.12"
# Notebook application object; every cell below registers itself on it via
# the @app.cell decorator.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Import the plotting and numerical libraries once and hand them to the
    # other cells through the return tuple.
    import matplotlib.pyplot as plt
    import numpy as np

    return np, plt
|
|
|
|
|
|
@app.cell
def __(plt):
    # Set the global matplotlib typography: Times New Roman at 16 pt.
    # This cell only mutates plt.rcParams and defines no names.
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = 16
    return
|
|
|
|
|
|
@app.cell
def __(np, plt):
    # Draw a three-panel figure comparing the communication cost of "TP"
    # against "LoRAPP and 1F1B" as the number of GPUs grows from 2 to 8, one
    # panel per model configuration, and save it to "pp_cmp_com_cost.pdf".
    #
    # `np` stays in the signature so the cell interface is unchanged; the
    # only use of it was a dead `x = np.arange(4)` store that was overwritten
    # before being read, so that assignment has been dropped.
    fig, ax = plt.subplots(figsize=(7, 2.6), ncols=3, layout="constrained")

    # RGB colour palette (c_1 and c_3 are not used for drawing here but are
    # still exported by the cell).
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Cost model, with d = number of GPUs:
    #   tp = 8 * l * b * h * (d-1)/d
    #   pp = 2 * b * h * (d - 1)
    # `token_size` plays the role of b in the formulas above.
    # NOTE(review): 512 * 8 * 4 / 1024**3 looks like seq_len * batch *
    # bytes-per-element expressed in GiB -- confirm against the paper/caller.
    token_size = 512 * 8 * 4 / 1024 / 1024 / 1024
    x = [2, 3, 4, 5, 6, 7, 8]

    # (l, h) pairs for the three panels; presumably layer count and hidden
    # size of each model -- verify against the model configs.
    l1 = 22
    h1 = 2048

    l2 = 32
    h2 = 4096

    l3 = 40
    h3 = 5120

    ticks = [2, 3, 4, 5, 6, 7, 8]
    ticks_label = ["2", "3", "4", "5", "6", "7", "8"]
    y_ticks = [0, 5, 10, 15, 20]
    y_ticks_label = ["0", "5", "10", "15", "20"]

    # Underscore-prefixed names are cell-local in marimo, so the loop
    # temporaries below do not change the set of names this cell exports.
    _panels = [
        (l1, h1, "(a) TinyLlama-1.1B"),
        (l2, h2, "(b) Llama-2-7B"),
        (l3, h3, "(c) Llama-2-13B"),
    ]
    for _idx, (_layers, _hidden, _title) in enumerate(_panels):
        _axis = ax[_idx]
        _first = _idx == 0

        # Same expression order as the cost formulas above so the floating
        # point results are unchanged. After the loop, y_tp / y_pp hold the
        # values of the last panel, as before.
        y_tp = [8 * _layers * token_size * _hidden * (d - 1) / d for d in x]
        y_pp = [2 * token_size * _hidden * (d - 1) for d in x]

        # Only the first panel's lines carry labels, so the shared figure
        # legend shows each series exactly once.
        _tp_label = {"label": "TP"} if _first else {}
        _pp_label = {"label": "LoRAPP and 1F1B"} if _first else {}
        _axis.plot(x, y_tp, color=c_2, marker="o", **_tp_label)
        _axis.plot(x, y_pp, color=c_4, marker="*", **_pp_label)

        _axis.set_ylim([-1, 25])
        _axis.set_xticks(ticks)
        _axis.set_xticklabels(ticks_label)
        _axis.set_yticks(y_ticks)
        _axis.set_yticklabels(
            y_ticks_label, rotation=90, ha="center", va="center"
        )
        _axis.set_title(_title, fontsize=16)
        if _first:
            # The y-axis label is drawn once, on the leftmost panel only.
            _axis.set_ylabel("Communication Cost (GB)", fontsize=16)
        _axis.set_xlabel("Number of GPUs", fontsize=16)
        _axis.tick_params(labelsize=16, pad=7)

    # Single frameless legend anchored above the panels.
    fig.legend(
        ncol=2,
        bbox_to_anchor=(0.8, 1.17),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )

    # plt.show()
    plt.savefig("pp_cmp_com_cost.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        h1,
        h2,
        h3,
        l1,
        l2,
        l3,
        ticks,
        ticks_label,
        token_size,
        x,
        y_pp,
        y_ticks,
        y_ticks_label,
        y_tp,
    )
|
|
|
|
|
|
# Run the notebook's cells when the file is executed as a script.
if __name__ == "__main__":
    app.run()
|