# marimo notebook. Each cell below is a function decorated with @app.cell;
# names a cell returns are passed to cells that list them as parameters.
import marimo

__generated_with = "0.9.17"
app = marimo.App(width="medium")
@app.cell(hide_code=True)
def __():
    # Training-loss series (746 values, ten per row, in original order),
    # exported to downstream cells as `loss_data`.
    loss_data = [
        1.1504, 1.8653, 0.8047, 0.8011, 0.7557, 0.7558, 0.7686, 0.8166, 0.7632, 0.7042,
        0.7764, 0.7793, 0.7594, 0.6889, 0.7605, 0.7444, 0.7133, 0.7957, 0.7535, 0.7258,
        0.7579, 0.7365, 0.7603, 0.7677, 0.7089, 0.7511, 0.7398, 0.7675, 0.7696, 0.7312,
        0.7394, 0.7776, 0.7651, 0.7837, 0.7236, 0.7248, 0.7709, 0.7965, 0.7419, 0.7185,
        0.7465, 0.7611, 0.7585, 0.7572, 0.7142, 0.76, 0.7559, 0.728, 0.7448, 0.7327,
        0.8376, 0.7407, 0.8002, 0.7723, 0.7015, 0.7211, 0.7349, 0.686, 0.6961, 0.7333,
        0.6772, 0.7295, 0.7704, 0.7876, 0.6915, 0.6808, 0.7451, 0.7214, 0.6729, 0.6317,
        0.7705, 0.6895, 0.7668, 0.6853, 0.7305, 0.7695, 0.6863, 0.7153, 0.6849, 0.694,
        0.7782, 0.7391, 0.6886, 0.7047, 0.6776, 0.7424, 0.693, 0.7058, 0.7483, 0.6831,
        0.7003, 0.7386, 0.7016, 0.7174, 0.7187, 0.7034, 0.7384, 0.7061, 0.6798, 0.6592,
        0.7525, 0.6893, 0.6907, 0.7583, 0.6771, 0.7248, 0.6998, 0.721, 0.7273, 0.6645,
        0.681, 0.7265, 0.767, 0.7026, 0.6869, 0.712, 0.7179, 0.7331, 0.6911, 0.6397,
        0.7521, 0.7362, 0.7607, 0.6977, 0.7231, 0.7071, 0.6914, 0.7232, 0.7439, 0.7153,
        0.7321, 0.7417, 0.6834, 0.6809, 0.7136, 0.693, 0.799, 0.7099, 0.713, 0.6629,
        0.7151, 0.6783, 0.7342, 0.7265, 0.6635, 0.7187, 0.7536, 0.7108, 0.6714, 0.6664,
        0.6849, 0.7655, 0.715, 0.6977, 0.6581, 0.7254, 0.7484, 0.7495, 0.7121, 0.6926,
        0.7385, 0.6852, 0.7534, 0.6925, 0.693, 0.7008, 0.7422, 0.7369, 0.7251, 0.6688,
        0.7008, 0.7086, 0.7499, 0.714, 0.6598, 0.6839, 0.7528, 0.6966, 0.6823, 0.6741,
        0.7301, 0.6849, 0.6801, 0.6978, 0.7045, 0.7169, 0.7022, 0.7151, 0.6495, 0.7012,
        0.6495, 0.6711, 0.6328, 0.7056, 0.7132, 0.6827, 0.6053, 0.6725, 0.6957, 0.6427,
        0.6429, 0.5967, 0.6835, 0.6894, 0.6547, 0.6032, 0.6507, 0.6483, 0.6682, 0.6428,
        0.6406, 0.592, 0.659, 0.7028, 0.6311, 0.6656, 0.6097, 0.6929, 0.6125, 0.7286,
        0.6596, 0.6077, 0.6311, 0.6679, 0.6742, 0.6735, 0.6043, 0.6806, 0.6537, 0.6705,
        0.6872, 0.6431, 0.6422, 0.6652, 0.6829, 0.6346, 0.6018, 0.6642, 0.615, 0.6824,
        0.6876, 0.6384, 0.6755, 0.6957, 0.6386, 0.6264, 0.668, 0.6976, 0.6985, 0.6628,
        0.6726, 0.5897, 0.6394, 0.6693, 0.6596, 0.6884, 0.5967, 0.6659, 0.6609, 0.6627,
        0.6203, 0.5878, 0.6926, 0.6583, 0.6482, 0.6399, 0.6045, 0.6888, 0.6823, 0.6875,
        0.6638, 0.6232, 0.6539, 0.6908, 0.6612, 0.6684, 0.5917, 0.6398, 0.6927, 0.6658,
        0.6469, 0.6245, 0.6547, 0.6738, 0.6773, 0.6386, 0.6142, 0.6283, 0.6899, 0.6318,
        0.6394, 0.6183, 0.6262, 0.6869, 0.6384, 0.6482, 0.6399, 0.6193, 0.6551, 0.7235,
        0.6435, 0.6442, 0.7525, 0.652, 0.647, 0.6849, 0.6408, 0.7305, 0.6678, 0.6752,
        0.6074, 0.6647, 0.6876, 0.6393, 0.6602, 0.6236, 0.6326, 0.6666, 0.6481, 0.5922,
        0.622, 0.6422, 0.6694, 0.6335, 0.6088, 0.6967, 0.6156, 0.6546, 0.6196, 0.631,
        0.6438, 0.6131, 0.6886, 0.6725, 0.6249, 0.669, 0.608, 0.6764, 0.648, 0.7009,
        0.6284, 0.5715, 0.6558, 0.6604, 0.6535, 0.6345, 0.598, 0.6399, 0.6468, 0.6013,
        0.6425, 0.6382, 0.686, 0.6616, 0.704, 0.6403, 0.5649, 0.6857, 0.6999, 0.6479,
        0.6419, 0.6218, 0.691, 0.6876, 0.6757, 0.6217, 0.5572, 0.7362, 0.6639, 0.6607,
        0.6252, 0.6434, 0.6434, 0.5952, 0.6062, 0.6104, 0.5933, 0.5873, 0.5627, 0.5918,
        0.5934, 0.6291, 0.5767, 0.5255, 0.6127, 0.5781, 0.5905, 0.5633, 0.5585, 0.6539,
        0.6334, 0.6003, 0.5772, 0.5347, 0.6061, 0.6419, 0.5479, 0.5582, 0.5404, 0.6531,
        0.6028, 0.5482, 0.5579, 0.5644, 0.6064, 0.5913, 0.6302, 0.5631, 0.5461, 0.6551,
        0.6142, 0.6295, 0.5712, 0.5677, 0.6012, 0.5998, 0.5688, 0.5585, 0.5643, 0.5889,
        0.6405, 0.5609, 0.5574, 0.571, 0.616, 0.6381, 0.5958, 0.5904, 0.5562, 0.5759,
        0.6378, 0.5804, 0.5568, 0.5411, 0.6559, 0.6074, 0.6196, 0.57, 0.5601, 0.6041,
        0.6512, 0.6167, 0.5851, 0.532, 0.6477, 0.5868, 0.5786, 0.5452, 0.577, 0.5936,
        0.6291, 0.6129, 0.5574, 0.5493, 0.5868, 0.6191, 0.5933, 0.6468, 0.5067, 0.6535,
        0.6046, 0.5802, 0.5826, 0.552, 0.6254, 0.5682, 0.545, 0.5451, 0.5221, 0.6329,
        0.5853, 0.6029, 0.5443, 0.5354, 0.6419, 0.6439, 0.5661, 0.5551, 0.5512, 0.6203,
        0.6219, 0.6153, 0.5726, 0.5171, 0.5946, 0.6604, 0.6185, 0.5895, 0.5561, 0.5905,
        0.5777, 0.6167, 0.546, 0.5482, 0.582, 0.5743, 0.6559, 0.5497, 0.5518, 0.5805,
        0.6465, 0.5864, 0.5589, 0.5439, 0.6347, 0.6263, 0.5779, 0.5725, 0.5504, 0.6412,
        0.6184, 0.6223, 0.5872, 0.5937, 0.6088, 0.5768, 0.5967, 0.6348, 0.5651, 0.6327,
        0.6183, 0.5749, 0.6044, 0.5796, 0.6044, 0.6142, 0.6183, 0.5729, 0.5009, 0.5938,
        0.6065, 0.5894, 0.5798, 0.5398, 0.6161, 0.6011, 0.6064, 0.6147, 0.5559, 0.6146,
        0.5655, 0.5756, 0.6018, 0.5448, 0.6312, 0.6232, 0.5807, 0.5784, 0.5462, 0.6209,
        0.5682, 0.6031, 0.5688, 0.5668, 0.6102, 0.6193, 0.5817, 0.5811, 0.5007, 0.6064,
        0.5597, 0.5679, 0.5397, 0.5281, 0.5098, 0.5147, 0.5747, 0.5386, 0.5585, 0.474,
        0.487, 0.5741, 0.5509, 0.5243, 0.5439, 0.5177, 0.5553, 0.5518, 0.5512, 0.5187,
        0.491, 0.5827, 0.548, 0.5553, 0.491, 0.434, 0.5807, 0.5702, 0.6053, 0.4806,
        0.4606, 0.607, 0.5538, 0.519, 0.5139, 0.5007, 0.5968, 0.5643, 0.5134, 0.4787,
        0.4608, 0.5629, 0.5295, 0.5245, 0.5075, 0.4814, 0.5417, 0.5736, 0.5569, 0.4928,
        0.5207, 0.5686, 0.5775, 0.5218, 0.4851, 0.507, 0.546, 0.5576, 0.5191, 0.4948,
        0.5287, 0.5537, 0.5625, 0.5107, 0.5059, 0.4703, 0.6103, 0.5216, 0.5344, 0.4919,
        0.4677, 0.5908, 0.5659, 0.5166, 0.519, 0.4767, 0.5625, 0.5085, 0.4887, 0.4936,
        0.4947, 0.5443, 0.5458, 0.5185, 0.4895, 0.4643, 0.5534, 0.5632, 0.5568, 0.5118,
        0.539, 0.516, 0.5417, 0.5192, 0.5115, 0.4897, 0.5493, 0.5564, 0.506, 0.4873,
        0.5172, 0.5835, 0.5571, 0.5338, 0.5408, 0.4995, 0.5715, 0.551, 0.5058, 0.5434,
        0.506, 0.5536, 0.5519, 0.5712, 0.4969, 0.4763, 0.5485, 0.5891, 0.5313, 0.5408,
        0.4994, 0.6022, 0.5665, 0.5388, 0.474, 0.4552, 0.5447, 0.5727, 0.5203, 0.4823,
        0.5249, 0.576, 0.5412, 0.5365, 0.493, 0.5027, 0.5552, 0.5302, 0.5154, 0.5185,
        0.4982, 0.5412, 0.519, 0.5801, 0.5254, 0.4857, 0.5943, 0.5629, 0.5488, 0.4911,
        0.5192, 0.5861, 0.5268, 0.511, 0.4939, 0.5551, 0.5396, 0.5397, 0.4844, 0.4749,
        0.5745, 0.5412, 0.5219, 0.5113, 0.4973, 0.5877, 0.5216, 0.5343, 0.4973, 0.4757,
        0.5476, 0.5714, 0.5668, 0.5235, 0.4618, 0.5758, 0.5278, 0.5091, 0.4877, 0.46,
        0.571, 0.5575, 0.526, 0.5028, 0.4955, 0.5487,
    ]
    return (loss_data,)
@app.cell
def __():
    # Shared imports. The returned names are consumed by later cells that
    # declare them as parameters (e.g. `plt`, `np`, `random`).
    import random

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap

    return LinearSegmentedColormap, np, plt, random
@app.cell
def __(plt):
    # Global Matplotlib font settings for every figure drawn afterwards;
    # keys are applied in the listed order.
    plt.rcParams.update(
        {
            'font.family': 'Times New Roman',
            'font.size': 14,
        }
    )
    return
@app.cell
def __(loss_data, np, plt, random):
    # Draw a 1x2 figure:
    #   axs[0] / "(a)": MAPE (%) versus number of data points used for fitting;
    #   axs[1] / "(b)": smoothed training-loss curves labelled PEFT and mLoRA.

    # Candidate D values. Not used by the plotting code in this cell; kept
    # because the cell exports the name.
    D_values = [4, 8]

    # One row of two Axes; flatten() gives a 1-D array indexable as axs[0]/axs[1].
    fig, axs = plt.subplots(1, 2, figsize=(7, 1), dpi=400)
    axs = axs.flatten()

    def smooth_curve(points, factor=0.8):
        """Return an exponential moving average of *points* as a list.

        The first output equals the first input; each later output is
        ``previous_output * factor + point * (1 - factor)``. An empty input
        gives an empty list.
        """
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                smoothed_points.append(point)
        return smoothed_points

    # Palette as RGB tuples in [0, 1], defined once. c_1 and c_3 are not
    # used in this cell but are exported.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # ---- Panel (b): training loss ------------------------------------------
    # Underscore-prefixed names stay local to this marimo cell.
    _iterations = range(len(loss_data))
    # A dedicated seeded generator makes the figure identical on every re-run
    # and leaves the global `random` state untouched.
    _rng = random.Random(0)
    # NOTE(review): the "mLoRA" series is not a separate measurement; it is
    # loss_data with uniform noise in [-0.1, 0.1] added to each point. If this
    # figure reports experimental results, plot the real mLoRA loss log.
    lorapp_loss = [loss + _rng.uniform(-0.1, 0.1) for loss in loss_data]

    axs[1].plot(_iterations, smooth_curve(loss_data), label="PEFT", color=c_2)
    axs[1].plot(_iterations, smooth_curve(lorapp_loss), label="mLoRA", color=c_4)
    axs[1].set_xlabel("Training iteration", fontsize=14)
    axs[1].set_ylabel("Loss", fontsize=14)
    axs[1].set_ylim(0.0, 1.35)
    axs[1].set_yticks(
        [0.45, 0.9, 1.35],
        ["0.45", "0.9", "1.35"],
        rotation=90,
        ha="center",
        va="center",
    )
    axs[1].set_xticks([0, 400, 800], ["0", "400", "800"], va="top")
    axs[1].set_xlim(-100, 900)
    axs[1].tick_params(pad=7)
    axs[1].legend(ncol=1, fancybox=False, framealpha=0.0, fontsize=14)

    # ---- Panel (a): MAPE vs. number of fitting points ----------------------
    # Listed from most to fewest points, then reversed so x is increasing.
    x = [71, 63, 55, 47, 39, 31, 23, 15, 7, 3]
    y = [
        0.266739094408014,
        0.23996592483352167,
        0.2554257707424939,
        0.22216522633727626,
        0.24286307818113806,
        0.21479247707365348,
        0.202277902928929,
        0.2755166593501294,
        0.712087464890641,
        0.9916578900340629,
    ]
    x = x[::-1]
    y = np.array(y[::-1])

    axs[0].plot(x, y, color=c_4)
    axs[0].set_ylabel("MAPE (%)", fontsize=14)
    axs[0].set_xlabel("Number of data points used for fitting ", fontsize=14)
    axs[0].set_yticks(
        [2, 1, 0], ["2", "1", "0"], rotation=90, ha="center", va="center"
    )
    axs[0].tick_params(pad=7)

    # Panel labels, each placed in its own Axes' coordinate system
    # (previously "(b)" was added to axs[0] while using axs[1].transAxes).
    axs[0].text(
        0.5,
        1.05,
        "(a)",
        fontsize=16,
        va="bottom",
        ha="right",
        transform=axs[0].transAxes,
        color="black",
    )
    axs[1].text(
        0.5,
        1.05,
        "(b)",
        fontsize=16,
        va="bottom",
        ha="right",
        transform=axs[1].transAxes,
        color="black",
    )

    # Uncomment to export the figure:
    # plt.savefig("map-and-loss.pdf", bbox_inches="tight", dpi=1000)
    return (
        D_values,
        axs,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        lorapp_loss,
        smooth_curve,
        x,
        y,
    )
# Executing this file directly runs the notebook through app.run();
# importing it (e.g. from the marimo editor) does not.
if __name__ == "__main__":
    app.run()