94 lines
2.2 KiB
Python
94 lines
2.2 KiB
Python
import marimo
|
|
|
|
# Version string recorded with the notebook (presumably the marimo release
# that last wrote this file — confirm before editing by hand).
__generated_with = "0.9.17"

# Notebook application object; the cells below attach to it via `@app.cell`.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def __():
    # Setup cell: import the libraries handed to other cells and apply the
    # figure-wide font defaults.
    import random

    import matplotlib.pyplot as plt
    import numpy as np

    # One bulk update instead of two separate item assignments; the keys and
    # values are unchanged.
    plt.rcParams.update(
        {
            "font.family": "Times New Roman",
            "font.size": 16,
        }
    )
    return np, plt, random
|
|
|
|
|
|
@app.cell
def __(plt):
    # Draws a three-panel figure of throughput (tokens/s) against LoRA adapter
    # rank, one panel per model named in the panel titles.
    #
    # Helper names introduced here start with an underscore so that marimo
    # keeps them local to this cell; the names the cell returns are unchanged.
    fig, ax = plt.subplots(figsize=(7, 2.8), ncols=3, layout="constrained")

    # Geometry values that no call in this cell uses; they are still returned.
    space_width = 3 / 22
    bar_width = 3 * space_width

    # RGB triples with channels scaled to [0, 1]. Only c_1 is used for drawing
    # in this cell; the others are returned unused.
    c_1 = (139 / 255, 0 / 255, 0 / 255)
    c_2 = (0, 0, 0)
    c_3 = (191 / 255, 191 / 255, 191 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # Series for panel 0.
    x_0 = [4, 8, 16, 32]
    y_0 = [10525.52, 10467.51, 10315.85, 10286.17]

    # Series for panel 1.
    x_1 = [16, 32, 64, 128]
    y_1 = [2309.40, 2258.73, 2252.43, 2242.80]

    # Series for panel 2.
    x_2 = [16, 32, 64, 128]
    y_2 = [1245.79, 1244.60, 1224.91, 1207.40]

    # (axes, x data, y data, x tick positions, title) for each panel.
    # Panels 1 and 2 plot a point at rank 16 but place no tick there.
    _panels = (
        (ax[0], x_0, y_0, x_0, "(d) TinyLlama-1.1B"),
        (ax[1], x_1, y_1, [32, 64, 128], "(e) Llama2-7B"),
        (ax[2], x_2, y_2, [32, 64, 128], "(f) Llama2-13B"),
    )

    for _axis, _xs, _ys, _xticks, _title in _panels:
        _axis.plot(_xs, _ys, "s-", color=c_1)
        # Every panel uses the same y range and the same three y ticks.
        _axis.set_ylim(0, 11000)
        _axis.set_xticks(_xticks)
        _axis.set_yticks([0, 5000, 10000])
        _axis.set_yticklabels(
            ["0", "5k", "10k"], rotation=90, ha="center", va="center"
        )
        _axis.set_title(_title, fontsize=16)
        _axis.set_xlabel("Rank of LoRA adapters", fontsize=14)

    # Only the leftmost panel carries the y-axis label.
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=16)

    #plt.savefig("lora_rank_adaptability.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        bar_width,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        space_width,
        x_0,
        x_1,
        x_2,
        y_0,
        y_1,
        y_2,
    )
|
|
|
|
|
|
# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
|