import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")


@app.cell(hide_code=True)
def __():
    # Per-iteration training loss; the figure cell plots it as the "PEFT"
    # curve against "Training iteration".
    loss_data = [
        1.1504, 1.8653, 0.8047, 0.8011, 0.7557, 0.7558, 0.7686, 0.8166, 0.7632, 0.7042,
        0.7764, 0.7793, 0.7594, 0.6889, 0.7605, 0.7444, 0.7133, 0.7957, 0.7535, 0.7258,
        0.7579, 0.7365, 0.7603, 0.7677, 0.7089, 0.7511, 0.7398, 0.7675, 0.7696, 0.7312,
        0.7394, 0.7776, 0.7651, 0.7837, 0.7236, 0.7248, 0.7709, 0.7965, 0.7419, 0.7185,
        0.7465, 0.7611, 0.7585, 0.7572, 0.7142, 0.76, 0.7559, 0.728, 0.7448, 0.7327,
        0.8376, 0.7407, 0.8002, 0.7723, 0.7015, 0.7211, 0.7349, 0.686, 0.6961, 0.7333,
        0.6772, 0.7295, 0.7704, 0.7876, 0.6915, 0.6808, 0.7451, 0.7214, 0.6729, 0.6317,
        0.7705, 0.6895, 0.7668, 0.6853, 0.7305, 0.7695, 0.6863, 0.7153, 0.6849, 0.694,
        0.7782, 0.7391, 0.6886, 0.7047, 0.6776, 0.7424, 0.693, 0.7058, 0.7483, 0.6831,
        0.7003, 0.7386, 0.7016, 0.7174, 0.7187, 0.7034, 0.7384, 0.7061, 0.6798, 0.6592,
        0.7525, 0.6893, 0.6907, 0.7583, 0.6771, 0.7248, 0.6998, 0.721, 0.7273, 0.6645,
        0.681, 0.7265, 0.767, 0.7026, 0.6869, 0.712, 0.7179, 0.7331, 0.6911, 0.6397,
        0.7521, 0.7362, 0.7607, 0.6977, 0.7231, 0.7071, 0.6914, 0.7232, 0.7439, 0.7153,
        0.7321, 0.7417, 0.6834, 0.6809, 0.7136, 0.693, 0.799, 0.7099, 0.713, 0.6629,
        0.7151, 0.6783, 0.7342, 0.7265, 0.6635, 0.7187, 0.7536, 0.7108, 0.6714, 0.6664,
        0.6849, 0.7655, 0.715, 0.6977, 0.6581, 0.7254, 0.7484, 0.7495, 0.7121, 0.6926,
        0.7385, 0.6852, 0.7534, 0.6925, 0.693, 0.7008, 0.7422, 0.7369, 0.7251, 0.6688,
        0.7008, 0.7086, 0.7499, 0.714, 0.6598, 0.6839, 0.7528, 0.6966, 0.6823, 0.6741,
        0.7301, 0.6849, 0.6801, 0.6978, 0.7045, 0.7169, 0.7022, 0.7151, 0.6495, 0.7012,
        0.6495, 0.6711, 0.6328, 0.7056, 0.7132, 0.6827, 0.6053, 0.6725, 0.6957, 0.6427,
        0.6429, 0.5967, 0.6835, 0.6894, 0.6547, 0.6032, 0.6507, 0.6483, 0.6682, 0.6428,
        0.6406, 0.592, 0.659, 0.7028, 0.6311, 0.6656, 0.6097, 0.6929, 0.6125, 0.7286,
        0.6596, 0.6077, 0.6311, 0.6679, 0.6742, 0.6735, 0.6043, 0.6806, 0.6537, 0.6705,
        0.6872, 0.6431, 0.6422, 0.6652, 0.6829, 0.6346,
        0.6018, 0.6642, 0.615, 0.6824, 0.6876, 0.6384, 0.6755, 0.6957, 0.6386, 0.6264,
        0.668, 0.6976, 0.6985, 0.6628, 0.6726, 0.5897, 0.6394, 0.6693, 0.6596, 0.6884,
        0.5967, 0.6659, 0.6609, 0.6627, 0.6203, 0.5878, 0.6926, 0.6583, 0.6482, 0.6399,
        0.6045, 0.6888, 0.6823, 0.6875, 0.6638, 0.6232, 0.6539, 0.6908, 0.6612, 0.6684,
        0.5917, 0.6398, 0.6927, 0.6658, 0.6469, 0.6245, 0.6547, 0.6738, 0.6773, 0.6386,
        0.6142, 0.6283, 0.6899, 0.6318, 0.6394, 0.6183, 0.6262, 0.6869, 0.6384, 0.6482,
        0.6399, 0.6193, 0.6551, 0.7235, 0.6435, 0.6442, 0.7525, 0.652, 0.647, 0.6849,
        0.6408, 0.7305, 0.6678, 0.6752, 0.6074, 0.6647, 0.6876, 0.6393, 0.6602, 0.6236,
        0.6326, 0.6666, 0.6481, 0.5922, 0.622, 0.6422, 0.6694, 0.6335, 0.6088, 0.6967,
        0.6156, 0.6546, 0.6196, 0.631, 0.6438, 0.6131, 0.6886, 0.6725, 0.6249, 0.669,
        0.608, 0.6764, 0.648, 0.7009, 0.6284, 0.5715, 0.6558, 0.6604, 0.6535, 0.6345,
        0.598, 0.6399, 0.6468, 0.6013, 0.6425, 0.6382, 0.686, 0.6616, 0.704, 0.6403,
        0.5649, 0.6857, 0.6999, 0.6479, 0.6419, 0.6218, 0.691, 0.6876, 0.6757, 0.6217,
        0.5572, 0.7362, 0.6639, 0.6607, 0.6252, 0.6434, 0.6434, 0.5952, 0.6062, 0.6104,
        0.5933, 0.5873, 0.5627, 0.5918, 0.5934, 0.6291, 0.5767, 0.5255, 0.6127, 0.5781,
        0.5905, 0.5633, 0.5585, 0.6539, 0.6334, 0.6003, 0.5772, 0.5347, 0.6061, 0.6419,
        0.5479, 0.5582, 0.5404, 0.6531, 0.6028, 0.5482, 0.5579, 0.5644, 0.6064, 0.5913,
        0.6302, 0.5631, 0.5461, 0.6551, 0.6142, 0.6295, 0.5712, 0.5677, 0.6012, 0.5998,
        0.5688, 0.5585, 0.5643, 0.5889, 0.6405, 0.5609, 0.5574, 0.571, 0.616, 0.6381,
        0.5958, 0.5904, 0.5562, 0.5759, 0.6378, 0.5804, 0.5568, 0.5411, 0.6559, 0.6074,
        0.6196, 0.57, 0.5601, 0.6041, 0.6512, 0.6167, 0.5851, 0.532, 0.6477, 0.5868,
        0.5786, 0.5452, 0.577, 0.5936, 0.6291, 0.6129, 0.5574, 0.5493, 0.5868, 0.6191,
        0.5933, 0.6468, 0.5067, 0.6535, 0.6046, 0.5802, 0.5826, 0.552, 0.6254, 0.5682,
        0.545, 0.5451, 0.5221, 0.6329, 0.5853, 0.6029, 0.5443, 0.5354, 0.6419, 0.6439,
        0.5661, 0.5551, 0.5512, 0.6203, 0.6219, 0.6153, 0.5726, 0.5171, 0.5946, 0.6604,
        0.6185, 0.5895,
        0.5561, 0.5905, 0.5777, 0.6167, 0.546, 0.5482, 0.582, 0.5743, 0.6559, 0.5497,
        0.5518, 0.5805, 0.6465, 0.5864, 0.5589, 0.5439, 0.6347, 0.6263, 0.5779, 0.5725,
        0.5504, 0.6412, 0.6184, 0.6223, 0.5872, 0.5937, 0.6088, 0.5768, 0.5967, 0.6348,
        0.5651, 0.6327, 0.6183, 0.5749, 0.6044, 0.5796, 0.6044, 0.6142, 0.6183, 0.5729,
        0.5009, 0.5938, 0.6065, 0.5894, 0.5798, 0.5398, 0.6161, 0.6011, 0.6064, 0.6147,
        0.5559, 0.6146, 0.5655, 0.5756, 0.6018, 0.5448, 0.6312, 0.6232, 0.5807, 0.5784,
        0.5462, 0.6209, 0.5682, 0.6031, 0.5688, 0.5668, 0.6102, 0.6193, 0.5817, 0.5811,
        0.5007, 0.6064, 0.5597, 0.5679, 0.5397, 0.5281, 0.5098, 0.5147, 0.5747, 0.5386,
        0.5585, 0.474, 0.487, 0.5741, 0.5509, 0.5243, 0.5439, 0.5177, 0.5553, 0.5518,
        0.5512, 0.5187, 0.491, 0.5827, 0.548, 0.5553, 0.491, 0.434, 0.5807, 0.5702,
        0.6053, 0.4806, 0.4606, 0.607, 0.5538, 0.519, 0.5139, 0.5007, 0.5968, 0.5643,
        0.5134, 0.4787, 0.4608, 0.5629, 0.5295, 0.5245, 0.5075, 0.4814, 0.5417, 0.5736,
        0.5569, 0.4928, 0.5207, 0.5686, 0.5775, 0.5218, 0.4851, 0.507, 0.546, 0.5576,
        0.5191, 0.4948, 0.5287, 0.5537, 0.5625, 0.5107, 0.5059, 0.4703, 0.6103, 0.5216,
        0.5344, 0.4919, 0.4677, 0.5908, 0.5659, 0.5166, 0.519, 0.4767, 0.5625, 0.5085,
        0.4887, 0.4936, 0.4947, 0.5443, 0.5458, 0.5185, 0.4895, 0.4643, 0.5534, 0.5632,
        0.5568, 0.5118, 0.539, 0.516, 0.5417, 0.5192, 0.5115, 0.4897, 0.5493, 0.5564,
        0.506, 0.4873, 0.5172, 0.5835, 0.5571, 0.5338, 0.5408, 0.4995, 0.5715, 0.551,
        0.5058, 0.5434, 0.506, 0.5536, 0.5519, 0.5712, 0.4969, 0.4763, 0.5485, 0.5891,
        0.5313, 0.5408, 0.4994, 0.6022, 0.5665, 0.5388, 0.474, 0.4552, 0.5447, 0.5727,
        0.5203, 0.4823, 0.5249, 0.576, 0.5412, 0.5365, 0.493, 0.5027, 0.5552, 0.5302,
        0.5154, 0.5185, 0.4982, 0.5412, 0.519, 0.5801, 0.5254, 0.4857, 0.5943, 0.5629,
        0.5488, 0.4911, 0.5192, 0.5861, 0.5268, 0.511, 0.4939, 0.5551, 0.5396, 0.5397,
        0.4844, 0.4749, 0.5745, 0.5412, 0.5219, 0.5113, 0.4973, 0.5877, 0.5216, 0.5343,
        0.4973, 0.4757, 0.5476, 0.5714, 0.5668, 0.5235, 0.4618, 0.5758, 0.5278, 0.5091,
        0.4877, 0.46, 0.571,
        0.5575, 0.526, 0.5028, 0.4955, 0.5487,
    ]
    return (loss_data,)


@app.cell
def __():
    # Shared imports. LinearSegmentedColormap is exported but not used by any
    # cell in this notebook.
    import numpy as np
    import matplotlib.pyplot as plt
    import random
    from matplotlib.colors import LinearSegmentedColormap
    return LinearSegmentedColormap, np, plt, random


@app.cell
def __(plt):
    # Global figure style: Times New Roman, 14 pt.
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 14
    return


@app.cell
def __(loss_data, np, plt, random):
    # List of D values (not used by the panels below; exported by the cell).
    D_values = [4, 8]

    # One row with two side-by-side panels:
    #   (a) MAPE vs. number of data points used for fitting
    #   (b) smoothed training-loss curves
    fig, axs = plt.subplots(1, 2, figsize=(7, 1), dpi=400)
    axs = axs.flatten()  # flatten the Axes array so panels are axs[0], axs[1]

    def smooth_curve(points, factor=0.8):
        """Return an exponential moving average of *points* as a new list.

        The first output equals the first input; every later output is
        ``previous_output * factor + point * (1 - factor)``. An empty input
        yields an empty list.
        """
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                # previous smoothed value * factor + current value * (1 - factor)
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                # Seed the average with the first raw value.
                smoothed_points.append(point)
        return smoothed_points

    # Shared palette as RGB tuples in [0, 1] (c_1 and c_3 are exported but
    # not drawn in this figure).
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # ---- Panel (b): smoothed training-loss curves ----
    _iterations = range(len(loss_data))
    # NOTE(review): the "mLoRA" series is NOT measured data. It is the PEFT
    # loss with uniform noise in [-0.1, 0.1] added to each point. Replace it
    # with the real mLoRA training log before using this figure. Until then,
    # a fixed-seed private generator keeps the figure reproducible across
    # runs and leaves the global `random` state untouched.
    _rng = random.Random(0)
    lorapp_loss = [loss + _rng.uniform(-0.1, 0.1) for loss in loss_data]

    axs[1].plot(_iterations, smooth_curve(loss_data), label="PEFT", color=c_2)
    axs[1].plot(_iterations, smooth_curve(lorapp_loss), label="mLoRA", color=c_4)
    axs[1].set_xlabel("Training iteration", fontsize=14)
    axs[1].set_ylabel("Loss", fontsize=14)
    axs[1].set_ylim(0.0, 1.35)
    axs[1].set_yticks(
        [0.45, 0.9, 1.35],
        ["0.45", "0.9", "1.35"],
        rotation=90,
        ha="center",
        va="center",
    )
    axs[1].set_xticks([0, 400, 800], ["0", "400", "800"], va="top")
    axs[1].set_xlim(-100, 900)
    axs[1].tick_params(pad=7)
    axs[1].legend(ncol=1, fancybox=False, framealpha=0.0, fontsize=14)

    # ---- Panel (a): MAPE vs. number of points used for fitting ----
    x = [71, 63, 55, 47, 39, 31, 23, 15, 7, 3]
    y = [
        0.266739094408014,
        0.23996592483352167,
        0.2554257707424939,
        0.22216522633727626,
        0.24286307818113806,
        0.21479247707365348,
        0.202277902928929,
        0.2755166593501294,
        0.712087464890641,
        0.9916578900340629,
    ]
    # Reverse both series so x runs in increasing order (3 -> 71).
    x = x[::-1]
    y = np.array(y[::-1])

    axs[0].plot(x, y, color=c_4)
    axs[0].set_ylabel("MAPE (%)", fontsize=14)
    axs[0].set_xlabel("Number of data points used for fitting ", fontsize=14)
    # Ticks at 0/1/2; matplotlib widens the view limits so all given ticks
    # are visible.
    axs[0].set_yticks(
        [2, 1, 0], ["2", "1", "0"], rotation=90, ha="center", va="center"
    )
    axs[0].tick_params(pad=7)

    # Panel labels, centred just above each panel (axes coordinates).
    axs[0].text(
        0.5,
        1.05,
        "(a)",
        fontsize=16,
        va="bottom",
        ha="right",
        transform=axs[0].transAxes,
        color="black",
    )
    # Attach "(b)" to the panel it labels (it was previously created on axs[0]
    # while positioned with axs[1]'s transform).
    axs[1].text(
        0.5,
        1.05,
        "(b)",
        fontsize=16,
        va="bottom",
        ha="right",
        transform=axs[1].transAxes,
        color="black",
    )

    # Uncomment to export the figure:
    # plt.savefig("map-and-loss.pdf", bbox_inches="tight", dpi=1000)

    # marimo renders a cell's last expression as its output; without this the
    # figure is never shown.
    fig
    return (
        D_values,
        axs,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        lorapp_loss,
        smooth_curve,
        x,
        y,
    )


if __name__ == "__main__":
    app.run()