185 lines
3.8 KiB
Python
185 lines
3.8 KiB
Python
import marimo
|
|
|
|
# Version string of the marimo release recorded in this notebook file.
__generated_with = "0.9.17"

# Notebook application object; each cell below registers on it via @app.cell.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Setup cell: import the plotting stack, set the global font defaults
    # and define the colour palette returned to the other cells.
    import matplotlib.pyplot as plt
    import numpy as np

    # Global matplotlib text defaults.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )

    # Palette as RGB tuples; 8-bit channel values are scaled by 1/255.
    c_1 = (139 / 255, 0.0, 0.0)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255,) * 3
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    return c_1, c_2, c_3, c_4, np, plt
|
|
|
|
|
|
@app.cell
def __(c_2, c_4, plt):
    # Build the two-panel grouped-bar figure and save it as "overlapping.pdf".
    #
    # ax[0] is panel "(a)" (throughput in tokens/s) and ax[1] is panel "(b)"
    # (proportion of the total time in percent). Each panel shows three
    # groups labelled "1.1B", "7B" and "13B"; every group is a pair of
    # adjacent bars, the first drawn in c_2 and the second in c_4.
    fig, ax = plt.subplots(1, 2, figsize=(7, 3.5), constrained_layout=True)

    # Horizontal layout shared by both panels: a bar is three times as wide
    # as the gap that precedes each pair of bars.
    space_width = 3 / 22
    bar_width = 3 * space_width

    # NOTE(review): Axes.bar centres each bar on its x value by default, so
    # the midpoint of group k is space_width + bar_width / 2 + k * (space_width
    # + 2 * bar_width). The ticks below sit half a space_width to the right of
    # that midpoint. Kept as-is to leave the rendered figure unchanged;
    # confirm whether the offset is intended.
    x_ticks = [
        bar_width,
        space_width + 3 * bar_width,
        2 * space_width + 5 * bar_width,
    ]
    x_ticks_label = ["1.1B", "7B", "13B"]

    def _draw_panel(
        axis, pairs, legend_labels, y_label, y_ticks, y_tick_labels, title
    ):
        """Draw one grouped-bar panel and apply its axis formatting.

        Args:
            axis: Matplotlib axes to draw on.
            pairs: One ``(first_height, second_height)`` tuple per group,
                in left-to-right order.
            legend_labels: Maps ``(group_index, slot_index)`` to the legend
                label carried by that single bar; bars not listed get no
                label.
            y_label: Text of the y-axis label.
            y_ticks: Positions of the y-axis ticks.
            y_tick_labels: Text shown at each y-axis tick.
            title: Panel title.
        """
        for group, pair in enumerate(pairs):
            for slot, (height, colour) in enumerate(zip(pair, (c_2, c_4))):
                # Only pass ``label`` for bars that own a legend entry, so
                # unlabeled bars are created exactly as with a plain call.
                extra = {}
                if (group, slot) in legend_labels:
                    extra["label"] = legend_labels[(group, slot)]
                axis.bar(
                    (group + 1) * space_width + (2 * group + slot) * bar_width,
                    height,
                    bar_width,
                    color=colour,
                    **extra,
                )

        axis.set_xticks(x_ticks)
        axis.set_xticklabels(x_ticks_label)
        axis.tick_params(bottom=False, labelsize=14, pad=7)
        axis.set_ylabel(y_label, fontsize=16)
        axis.set_yticks(y_ticks)
        axis.set_yticklabels(
            y_tick_labels, rotation=90, ha="center", va="center"
        )
        axis.set_title(title, fontsize=16)
        axis.set_xlabel(
            "Base model with \ndifferent parameter scales", fontsize=16
        )

    # Panel (b) is drawn first, then panel (a), matching the original order.
    # The labelled bars are the second bar of group 0 and the first bar of
    # group 1, which fixes the order of the legend entries.
    _draw_panel(
        ax[1],
        pairs=[(45, 55), (75, 25), (85, 15)],
        legend_labels={(0, 1): "Communication", (1, 0): "Computation"},
        y_label="The proportion of \nthe total time (%)",
        y_ticks=[0, 50, 100],
        y_tick_labels=["0", "50", "100"],
        title="(b)",
    )

    ##### ##### #####

    _draw_panel(
        ax[0],
        pairs=[(7062.16, 11625), (1874.48, 2270), (1102, 1280)],
        legend_labels={
            (0, 1): "With overlapping",
            (1, 0): "Without overlapping",
        },
        y_label="Throughput (tokens/s)",
        y_ticks=[0, 2500, 5000, 7500, 10000],
        y_tick_labels=["0", "2.5k", "5k", "7.5k", "10k"],
        title="(a)",
    )

    # Both legends share the same frameless single-column style and differ
    # only in their anchor point.
    _legend_style = dict(ncol=1, fancybox=False, framealpha=0.0, fontsize=10)
    ax[1].legend(bbox_to_anchor=(0.55, 1), **_legend_style)
    ax[0].legend(bbox_to_anchor=(0.35, 1), **_legend_style)

    plt.savefig("overlapping.pdf", bbox_inches="tight", dpi=1000)
    return ax, bar_width, fig, space_width, x_ticks, x_ticks_label
|
|
|
|
|
|
# Run the notebook's cells when this file is executed as a script.
if __name__ == "__main__":
    app.run()
|