# paper_note/mlora/map-and-loss.py
# 2025-03-05 20:38:41 +08:00
#
# 897 lines
# 15 KiB
# Python

import marimo
# marimo version recorded when this notebook file was generated.
__generated_with = "0.9.17"
# Notebook application object; every @app.cell-decorated function below is one cell.
app = marimo.App(width="medium")
@app.cell(hide_code=True)
def __():
    # Raw training-loss series, one value per training iteration. The plotting
    # cell draws it (exponentially smoothed) against range(len(loss_data)) and
    # labels that curve "PEFT".
    # NOTE(review): the run that produced these numbers (model, dataset,
    # logging interval) is not recorded here -- document it if known.
    loss_data = [
        1.1504,
        1.8653,
        0.8047,
        0.8011,
        0.7557,
        0.7558,
        0.7686,
        0.8166,
        0.7632,
        0.7042,
        0.7764,
        0.7793,
        0.7594,
        0.6889,
        0.7605,
        0.7444,
        0.7133,
        0.7957,
        0.7535,
        0.7258,
        0.7579,
        0.7365,
        0.7603,
        0.7677,
        0.7089,
        0.7511,
        0.7398,
        0.7675,
        0.7696,
        0.7312,
        0.7394,
        0.7776,
        0.7651,
        0.7837,
        0.7236,
        0.7248,
        0.7709,
        0.7965,
        0.7419,
        0.7185,
        0.7465,
        0.7611,
        0.7585,
        0.7572,
        0.7142,
        0.76,
        0.7559,
        0.728,
        0.7448,
        0.7327,
        0.8376,
        0.7407,
        0.8002,
        0.7723,
        0.7015,
        0.7211,
        0.7349,
        0.686,
        0.6961,
        0.7333,
        0.6772,
        0.7295,
        0.7704,
        0.7876,
        0.6915,
        0.6808,
        0.7451,
        0.7214,
        0.6729,
        0.6317,
        0.7705,
        0.6895,
        0.7668,
        0.6853,
        0.7305,
        0.7695,
        0.6863,
        0.7153,
        0.6849,
        0.694,
        0.7782,
        0.7391,
        0.6886,
        0.7047,
        0.6776,
        0.7424,
        0.693,
        0.7058,
        0.7483,
        0.6831,
        0.7003,
        0.7386,
        0.7016,
        0.7174,
        0.7187,
        0.7034,
        0.7384,
        0.7061,
        0.6798,
        0.6592,
        0.7525,
        0.6893,
        0.6907,
        0.7583,
        0.6771,
        0.7248,
        0.6998,
        0.721,
        0.7273,
        0.6645,
        0.681,
        0.7265,
        0.767,
        0.7026,
        0.6869,
        0.712,
        0.7179,
        0.7331,
        0.6911,
        0.6397,
        0.7521,
        0.7362,
        0.7607,
        0.6977,
        0.7231,
        0.7071,
        0.6914,
        0.7232,
        0.7439,
        0.7153,
        0.7321,
        0.7417,
        0.6834,
        0.6809,
        0.7136,
        0.693,
        0.799,
        0.7099,
        0.713,
        0.6629,
        0.7151,
        0.6783,
        0.7342,
        0.7265,
        0.6635,
        0.7187,
        0.7536,
        0.7108,
        0.6714,
        0.6664,
        0.6849,
        0.7655,
        0.715,
        0.6977,
        0.6581,
        0.7254,
        0.7484,
        0.7495,
        0.7121,
        0.6926,
        0.7385,
        0.6852,
        0.7534,
        0.6925,
        0.693,
        0.7008,
        0.7422,
        0.7369,
        0.7251,
        0.6688,
        0.7008,
        0.7086,
        0.7499,
        0.714,
        0.6598,
        0.6839,
        0.7528,
        0.6966,
        0.6823,
        0.6741,
        0.7301,
        0.6849,
        0.6801,
        0.6978,
        0.7045,
        0.7169,
        0.7022,
        0.7151,
        0.6495,
        0.7012,
        0.6495,
        0.6711,
        0.6328,
        0.7056,
        0.7132,
        0.6827,
        0.6053,
        0.6725,
        0.6957,
        0.6427,
        0.6429,
        0.5967,
        0.6835,
        0.6894,
        0.6547,
        0.6032,
        0.6507,
        0.6483,
        0.6682,
        0.6428,
        0.6406,
        0.592,
        0.659,
        0.7028,
        0.6311,
        0.6656,
        0.6097,
        0.6929,
        0.6125,
        0.7286,
        0.6596,
        0.6077,
        0.6311,
        0.6679,
        0.6742,
        0.6735,
        0.6043,
        0.6806,
        0.6537,
        0.6705,
        0.6872,
        0.6431,
        0.6422,
        0.6652,
        0.6829,
        0.6346,
        0.6018,
        0.6642,
        0.615,
        0.6824,
        0.6876,
        0.6384,
        0.6755,
        0.6957,
        0.6386,
        0.6264,
        0.668,
        0.6976,
        0.6985,
        0.6628,
        0.6726,
        0.5897,
        0.6394,
        0.6693,
        0.6596,
        0.6884,
        0.5967,
        0.6659,
        0.6609,
        0.6627,
        0.6203,
        0.5878,
        0.6926,
        0.6583,
        0.6482,
        0.6399,
        0.6045,
        0.6888,
        0.6823,
        0.6875,
        0.6638,
        0.6232,
        0.6539,
        0.6908,
        0.6612,
        0.6684,
        0.5917,
        0.6398,
        0.6927,
        0.6658,
        0.6469,
        0.6245,
        0.6547,
        0.6738,
        0.6773,
        0.6386,
        0.6142,
        0.6283,
        0.6899,
        0.6318,
        0.6394,
        0.6183,
        0.6262,
        0.6869,
        0.6384,
        0.6482,
        0.6399,
        0.6193,
        0.6551,
        0.7235,
        0.6435,
        0.6442,
        0.7525,
        0.652,
        0.647,
        0.6849,
        0.6408,
        0.7305,
        0.6678,
        0.6752,
        0.6074,
        0.6647,
        0.6876,
        0.6393,
        0.6602,
        0.6236,
        0.6326,
        0.6666,
        0.6481,
        0.5922,
        0.622,
        0.6422,
        0.6694,
        0.6335,
        0.6088,
        0.6967,
        0.6156,
        0.6546,
        0.6196,
        0.631,
        0.6438,
        0.6131,
        0.6886,
        0.6725,
        0.6249,
        0.669,
        0.608,
        0.6764,
        0.648,
        0.7009,
        0.6284,
        0.5715,
        0.6558,
        0.6604,
        0.6535,
        0.6345,
        0.598,
        0.6399,
        0.6468,
        0.6013,
        0.6425,
        0.6382,
        0.686,
        0.6616,
        0.704,
        0.6403,
        0.5649,
        0.6857,
        0.6999,
        0.6479,
        0.6419,
        0.6218,
        0.691,
        0.6876,
        0.6757,
        0.6217,
        0.5572,
        0.7362,
        0.6639,
        0.6607,
        0.6252,
        0.6434,
        0.6434,
        0.5952,
        0.6062,
        0.6104,
        0.5933,
        0.5873,
        0.5627,
        0.5918,
        0.5934,
        0.6291,
        0.5767,
        0.5255,
        0.6127,
        0.5781,
        0.5905,
        0.5633,
        0.5585,
        0.6539,
        0.6334,
        0.6003,
        0.5772,
        0.5347,
        0.6061,
        0.6419,
        0.5479,
        0.5582,
        0.5404,
        0.6531,
        0.6028,
        0.5482,
        0.5579,
        0.5644,
        0.6064,
        0.5913,
        0.6302,
        0.5631,
        0.5461,
        0.6551,
        0.6142,
        0.6295,
        0.5712,
        0.5677,
        0.6012,
        0.5998,
        0.5688,
        0.5585,
        0.5643,
        0.5889,
        0.6405,
        0.5609,
        0.5574,
        0.571,
        0.616,
        0.6381,
        0.5958,
        0.5904,
        0.5562,
        0.5759,
        0.6378,
        0.5804,
        0.5568,
        0.5411,
        0.6559,
        0.6074,
        0.6196,
        0.57,
        0.5601,
        0.6041,
        0.6512,
        0.6167,
        0.5851,
        0.532,
        0.6477,
        0.5868,
        0.5786,
        0.5452,
        0.577,
        0.5936,
        0.6291,
        0.6129,
        0.5574,
        0.5493,
        0.5868,
        0.6191,
        0.5933,
        0.6468,
        0.5067,
        0.6535,
        0.6046,
        0.5802,
        0.5826,
        0.552,
        0.6254,
        0.5682,
        0.545,
        0.5451,
        0.5221,
        0.6329,
        0.5853,
        0.6029,
        0.5443,
        0.5354,
        0.6419,
        0.6439,
        0.5661,
        0.5551,
        0.5512,
        0.6203,
        0.6219,
        0.6153,
        0.5726,
        0.5171,
        0.5946,
        0.6604,
        0.6185,
        0.5895,
        0.5561,
        0.5905,
        0.5777,
        0.6167,
        0.546,
        0.5482,
        0.582,
        0.5743,
        0.6559,
        0.5497,
        0.5518,
        0.5805,
        0.6465,
        0.5864,
        0.5589,
        0.5439,
        0.6347,
        0.6263,
        0.5779,
        0.5725,
        0.5504,
        0.6412,
        0.6184,
        0.6223,
        0.5872,
        0.5937,
        0.6088,
        0.5768,
        0.5967,
        0.6348,
        0.5651,
        0.6327,
        0.6183,
        0.5749,
        0.6044,
        0.5796,
        0.6044,
        0.6142,
        0.6183,
        0.5729,
        0.5009,
        0.5938,
        0.6065,
        0.5894,
        0.5798,
        0.5398,
        0.6161,
        0.6011,
        0.6064,
        0.6147,
        0.5559,
        0.6146,
        0.5655,
        0.5756,
        0.6018,
        0.5448,
        0.6312,
        0.6232,
        0.5807,
        0.5784,
        0.5462,
        0.6209,
        0.5682,
        0.6031,
        0.5688,
        0.5668,
        0.6102,
        0.6193,
        0.5817,
        0.5811,
        0.5007,
        0.6064,
        0.5597,
        0.5679,
        0.5397,
        0.5281,
        0.5098,
        0.5147,
        0.5747,
        0.5386,
        0.5585,
        0.474,
        0.487,
        0.5741,
        0.5509,
        0.5243,
        0.5439,
        0.5177,
        0.5553,
        0.5518,
        0.5512,
        0.5187,
        0.491,
        0.5827,
        0.548,
        0.5553,
        0.491,
        0.434,
        0.5807,
        0.5702,
        0.6053,
        0.4806,
        0.4606,
        0.607,
        0.5538,
        0.519,
        0.5139,
        0.5007,
        0.5968,
        0.5643,
        0.5134,
        0.4787,
        0.4608,
        0.5629,
        0.5295,
        0.5245,
        0.5075,
        0.4814,
        0.5417,
        0.5736,
        0.5569,
        0.4928,
        0.5207,
        0.5686,
        0.5775,
        0.5218,
        0.4851,
        0.507,
        0.546,
        0.5576,
        0.5191,
        0.4948,
        0.5287,
        0.5537,
        0.5625,
        0.5107,
        0.5059,
        0.4703,
        0.6103,
        0.5216,
        0.5344,
        0.4919,
        0.4677,
        0.5908,
        0.5659,
        0.5166,
        0.519,
        0.4767,
        0.5625,
        0.5085,
        0.4887,
        0.4936,
        0.4947,
        0.5443,
        0.5458,
        0.5185,
        0.4895,
        0.4643,
        0.5534,
        0.5632,
        0.5568,
        0.5118,
        0.539,
        0.516,
        0.5417,
        0.5192,
        0.5115,
        0.4897,
        0.5493,
        0.5564,
        0.506,
        0.4873,
        0.5172,
        0.5835,
        0.5571,
        0.5338,
        0.5408,
        0.4995,
        0.5715,
        0.551,
        0.5058,
        0.5434,
        0.506,
        0.5536,
        0.5519,
        0.5712,
        0.4969,
        0.4763,
        0.5485,
        0.5891,
        0.5313,
        0.5408,
        0.4994,
        0.6022,
        0.5665,
        0.5388,
        0.474,
        0.4552,
        0.5447,
        0.5727,
        0.5203,
        0.4823,
        0.5249,
        0.576,
        0.5412,
        0.5365,
        0.493,
        0.5027,
        0.5552,
        0.5302,
        0.5154,
        0.5185,
        0.4982,
        0.5412,
        0.519,
        0.5801,
        0.5254,
        0.4857,
        0.5943,
        0.5629,
        0.5488,
        0.4911,
        0.5192,
        0.5861,
        0.5268,
        0.511,
        0.4939,
        0.5551,
        0.5396,
        0.5397,
        0.4844,
        0.4749,
        0.5745,
        0.5412,
        0.5219,
        0.5113,
        0.4973,
        0.5877,
        0.5216,
        0.5343,
        0.4973,
        0.4757,
        0.5476,
        0.5714,
        0.5668,
        0.5235,
        0.4618,
        0.5758,
        0.5278,
        0.5091,
        0.4877,
        0.46,
        0.571,
        0.5575,
        0.526,
        0.5028,
        0.4955,
        0.5487,
    ]
    return (loss_data,)
@app.cell
def __():
    # Modules shared with the other cells; marimo passes each returned name
    # to any cell that lists it as a parameter.
    import random

    import numpy as np
    from matplotlib import pyplot as plt
    # NOTE(review): LinearSegmentedColormap is exported but no visible cell
    # uses it.
    from matplotlib.colors import LinearSegmentedColormap

    return LinearSegmentedColormap, np, plt, random
@app.cell
def __(plt):
    # Global text style for every figure drawn afterwards: Times New Roman, 14 pt.
    plt.rcParams.update({"font.family": "Times New Roman", "font.size": 14})
    return
@app.cell
def __(loss_data, np, plt, random):
    # Builds the two-panel figure:
    #   (a) MAPE versus the number of data points used for fitting;
    #   (b) smoothed training-loss curves labelled "PEFT" and "mLoRA".

    # Given D values; not used by either panel below, but exported from the cell.
    D_values = [4, 8]

    # One row of two panels; flatten so they are addressed as axs[0] / axs[1].
    fig, axs = plt.subplots(1, 2, figsize=(7, 1), dpi=400)
    axs = axs.flatten()

    def smooth_curve(points, factor=0.8):
        """Return an exponentially smoothed copy of *points*.

        The first output equals the first input; every later output is
        ``previous_output * factor + point * (1 - factor)``. An empty input
        yields an empty list.

        Args:
            points: Iterable of numbers to smooth.
            factor: Weight kept from the previous smoothed value (the default
                0.8 keeps 80% of the history and takes 20% of the new point).

        Returns:
            A list with one smoothed value per input point.
        """
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                # Seed the average with the first raw value.
                smoothed_points.append(point)
        return smoothed_points

    # Palette shared by both panels (RGB components in [0, 1]).
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # ---- Panel (b): training loss -----------------------------------------
    # NOTE(review): the "mLoRA" series is synthesized here, not measured: it is
    # loss_data with independent uniform noise in [-0.1, 0.1] added to each
    # point. A private generator with a fixed seed makes the figure identical
    # on every run and leaves the global `random` state untouched.
    # Underscore-prefixed names are cell-local in marimo.
    _rng = random.Random(0)
    lorapp_loss = [value + _rng.uniform(-0.1, 0.1) for value in loss_data]
    _iterations = range(len(loss_data))
    axs[1].plot(
        _iterations, smooth_curve(loss_data), label="PEFT", color=c_2
    )
    axs[1].plot(
        _iterations, smooth_curve(lorapp_loss), label="mLoRA", color=c_4
    )
    axs[1].set_xlabel("Training iteration", fontsize=14)
    axs[1].set_ylabel("Loss", fontsize=14)
    axs[1].set_ylim(0.0, 1.35)
    axs[1].set_yticks(
        [0.45, 0.9, 1.35],
        ["0.45", "0.9", "1.35"],
        rotation=90,
        ha="center",
        va="center",
    )
    axs[1].set_xticks([0, 400, 800], ["0", "400", "800"], va="top")
    axs[1].set_xlim(-100, 900)
    axs[1].tick_params(pad=7)
    axs[1].legend(ncol=1, fancybox=False, framealpha=0.0, fontsize=14)

    # ---- Panel (a): fitting error -----------------------------------------
    # Number of data points used for fitting and the MAPE obtained with each,
    # listed from the most points to the fewest.
    x = [71, 63, 55, 47, 39, 31, 23, 15, 7, 3]
    y = [
        0.266739094408014,
        0.23996592483352167,
        0.2554257707424939,
        0.22216522633727626,
        0.24286307818113806,
        0.21479247707365348,
        0.202277902928929,
        0.2755166593501294,
        0.712087464890641,
        0.9916578900340629,
    ]
    # Reverse the data order (not the axis direction) so the point counts ascend.
    x = x[::-1]
    y = np.array(y[::-1])
    axs[0].plot(x, y, color=c_4)
    axs[0].set_ylabel("MAPE (%)", fontsize=14)
    axs[0].set_xlabel("Number of data points used for fitting ", fontsize=14)
    # Setting a tick at 2 also widens the y view limits to include it.
    axs[0].set_yticks(
        [2, 1, 0], ["2", "1", "0"], rotation=90, ha="center", va="center"
    )
    axs[0].tick_params(pad=7)

    # Panel labels placed just above each axes (axes-fraction coordinates).
    axs[0].text(
        0.5,
        1.05,
        "(a)",
        fontsize=16,
        va="bottom",
        ha="right",
        transform=axs[0].transAxes,
        color="black",
    )
    axs[0].text(
        0.5,
        1.05,
        "(b)",
        fontsize=16,
        va="bottom",
        ha="right",
        transform=axs[1].transAxes,
        color="black",
    )
    # Uncomment to export the figure for the paper.
    # plt.savefig("map-and-loss.pdf", bbox_inches="tight", dpi=1000)
    return (
        D_values,
        axs,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        lorapp_loss,
        smooth_curve,
        x,
        y,
    )
# Run the marimo app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()