import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell
def __():
    import matplotlib.pyplot as plt
    import numpy as np
    return np, plt


@app.cell
def __(plt):
    # Global font settings applied to every figure drawn in this notebook.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 16
    return


@app.cell
def __(plt):
    # Throughput (tokens/s) of four training strategies versus the number of
    # simultaneously trained LoRA adapters (1-4), one panel per model, saved
    # to "lorapp-task.pdf".
    #
    # Series naming -> legend label (see panel (b)):
    #   *_gpipe -> "1F1B", *_tp -> "TP", *_fsdp -> "FSDP", *_mLoRA -> "LoRAPP"

    # Line colors as RGB tuples in [0, 1]; each strategy keeps its color in
    # every panel.
    c_1 = (139 / 255, 0 / 255, 0 / 255)  # dark red   -> 1F1B
    c_2 = (0, 0, 0)  # black      -> TP
    c_3 = (191 / 255, 191 / 255, 191 / 255)  # light gray -> FSDP
    c_4 = (230 / 255, 109 / 255, 104 / 255)  # salmon     -> LoRAPP

    fig, ax = plt.subplots(figsize=(7, 2), ncols=3, layout="constrained")

    # (a) TinyLlama-1.1B
    x_b1 = [1, 2, 3, 4]
    b1_throughput_mLoRA = [4600.45, 8664.91, 10118.36, 10184.44]
    b1_throughput_tp = [5752.14, 5749.34, 5756.78, 5758.32]
    b1_throughput_fsdp = [6151.91, 6141.73, 6161.23, 6153.93]
    b1_throughput_gpipe = [4599.87, 4610.19, 4592.17, 4601.18]

    # (b) Llama-2-7B
    x_b7 = [1, 2, 3, 4]
    b7_throughput_mLoRA = [1274.87, 2250.46, 2362.69, 2363.89]
    b7_throughput_tp = [1614.18, 1610.26, 1620.07, 1613.34]
    b7_throughput_fsdp = [1695.37, 1705.97, 1686.05, 1693.45]
    b7_throughput_gpipe = [1284.27, 1273.89, 1272.14, 1279.64]

    # (c) Llama-2-13B
    x_b13 = [1, 2, 3, 4]
    b13_throughput_mLoRA = [723.21, 1280.54, 1286.54, 1282.54]
    b13_throughput_tp = [875, 877, 870, 878]
    # Placeholder: FSDP failed with OOM (annotated on the panel). The zeros
    # lie below the panel's y-limit (400) and are therefore clipped.
    b13_throughput_fsdp = [0, 0, 0, 0]
    b13_throughput_gpipe = [723.21, 719.21, 726.21, 723.21]

    def _style_axis(axis, ylim, yticks, title, fmt=str):
        """Apply the limits, ticks, and title styling shared by all panels.

        Y tick labels are derived from the tick positions via ``fmt``, so a
        label can never disagree with the position it is drawn at.
        """
        axis.set_ylim(*ylim)
        axis.set_xticks([1, 2, 3, 4])
        axis.set_xticklabels(["1", "2", "3", "4"])
        axis.set_yticks(yticks)
        axis.set_yticklabels(
            [fmt(tick) for tick in yticks],
            rotation=90,
            ha="center",
            va="center",
            fontsize=12,
        )
        axis.tick_params(pad=7)
        axis.set_title(title, fontsize=16)

    # Panel (a). Series are unlabeled here; the shared legend is built from
    # panel (b) only, so each strategy appears in it exactly once.
    ax[0].set_ylabel("Throughput (tokens/s)", fontsize=16)
    ax[0].plot(x_b1, b1_throughput_fsdp, color=c_3, marker="v")
    ax[0].plot(x_b1, b1_throughput_tp, color=c_2, marker="o")
    ax[0].plot(x_b1, b1_throughput_gpipe, color=c_1, marker="^")
    ax[0].plot(x_b1, b1_throughput_mLoRA, color=c_4, marker="*")
    _style_axis(
        ax[0],
        (3000, 12000),
        [5000, 7000, 9000, 11000],
        "(a) TinyLlama-1.1B",
        fmt=lambda tick: f"{tick // 1000}k",
    )

    # Panel (b): the only labeled series, which fig.legend collects below.
    ax[1].plot(x_b7, b7_throughput_gpipe, color=c_1, label="1F1B", marker="^")
    ax[1].plot(x_b7, b7_throughput_tp, color=c_2, label="TP", marker="o")
    ax[1].plot(x_b7, b7_throughput_fsdp, color=c_3, label="FSDP", marker="v")
    ax[1].plot(x_b7, b7_throughput_mLoRA, color=c_4, label="LoRAPP", marker="*")
    _style_axis(ax[1], (1000, 2500), [1200, 1700, 2200], "(b) Llama-2-7B")
    ax[1].set_xlabel("Number of simultaneously trained LoRA adapters", fontsize=16)

    # Panel (c). Y ticks are evenly spaced at 600/950/1300 and are now
    # labelled accordingly (the lowest tick was previously mislabelled "650").
    ax[2].plot(
        x_b13,
        b13_throughput_fsdp,
        color=c_3,
        marker="x",
        markerfacecolor="r",
        markeredgecolor="r",
    )
    ax[2].plot(x_b13, b13_throughput_tp, color=c_2, marker="o")
    ax[2].plot(x_b13, b13_throughput_gpipe, color=c_1, marker="^")
    ax[2].plot(x_b13, b13_throughput_mLoRA, color=c_4, marker="*")
    _style_axis(ax[2], (400, 1500), [600, 950, 1300], "(c) Llama-2-13B")
    ax[2].text(
        0.95,
        0.05,
        "FSDP : OOM",
        fontsize=14,
        va="bottom",
        ha="right",
        transform=ax[2].transAxes,
        color="r",
        style="italic",
    )

    # Single-row legend placed above the three panels.
    fig.legend(
        ncol=4,
        bbox_to_anchor=(0.97, 1.2),
        fancybox=False,
        framealpha=0.0,
        fontsize=16,
    )
    fig.savefig("lorapp-task.pdf", bbox_inches="tight", dpi=1000)
    return (
        ax,
        b13_throughput_fsdp,
        b13_throughput_gpipe,
        b13_throughput_mLoRA,
        b13_throughput_tp,
        b1_throughput_fsdp,
        b1_throughput_gpipe,
        b1_throughput_mLoRA,
        b1_throughput_tp,
        b7_throughput_fsdp,
        b7_throughput_gpipe,
        b7_throughput_mLoRA,
        b7_throughput_tp,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        x_b1,
        x_b13,
        x_b7,
    )


if __name__ == "__main__":
    app.run()