import marimo

# marimo version this notebook file was generated with.
__generated_with = "0.9.17"
# The notebook application object; every `@app.cell` function below is one cell.
app = marimo.App(width="medium")


@app.cell(hide_code=True)
def __():
    # Hard-coded training-loss values (code hidden in the notebook view).
    # The plotting cell uses the list index as its "Training iteration" x-axis
    # and draws these values as the "PEFT" loss curve; presumably one entry
    # per logged training iteration — confirm against the original run logs.
    loss_data = [
        1.1504,
        1.8653,
        0.8047,
        0.8011,
        0.7557,
        0.7558,
        0.7686,
        0.8166,
        0.7632,
        0.7042,
        0.7764,
        0.7793,
        0.7594,
        0.6889,
        0.7605,
        0.7444,
        0.7133,
        0.7957,
        0.7535,
        0.7258,
        0.7579,
        0.7365,
        0.7603,
        0.7677,
        0.7089,
        0.7511,
        0.7398,
        0.7675,
        0.7696,
        0.7312,
        0.7394,
        0.7776,
        0.7651,
        0.7837,
        0.7236,
        0.7248,
        0.7709,
        0.7965,
        0.7419,
        0.7185,
        0.7465,
        0.7611,
        0.7585,
        0.7572,
        0.7142,
        0.76,
        0.7559,
        0.728,
        0.7448,
        0.7327,
        0.8376,
        0.7407,
        0.8002,
        0.7723,
        0.7015,
        0.7211,
        0.7349,
        0.686,
        0.6961,
        0.7333,
        0.6772,
        0.7295,
        0.7704,
        0.7876,
        0.6915,
        0.6808,
        0.7451,
        0.7214,
        0.6729,
        0.6317,
        0.7705,
        0.6895,
        0.7668,
        0.6853,
        0.7305,
        0.7695,
        0.6863,
        0.7153,
        0.6849,
        0.694,
        0.7782,
        0.7391,
        0.6886,
        0.7047,
        0.6776,
        0.7424,
        0.693,
        0.7058,
        0.7483,
        0.6831,
        0.7003,
        0.7386,
        0.7016,
        0.7174,
        0.7187,
        0.7034,
        0.7384,
        0.7061,
        0.6798,
        0.6592,
        0.7525,
        0.6893,
        0.6907,
        0.7583,
        0.6771,
        0.7248,
        0.6998,
        0.721,
        0.7273,
        0.6645,
        0.681,
        0.7265,
        0.767,
        0.7026,
        0.6869,
        0.712,
        0.7179,
        0.7331,
        0.6911,
        0.6397,
        0.7521,
        0.7362,
        0.7607,
        0.6977,
        0.7231,
        0.7071,
        0.6914,
        0.7232,
        0.7439,
        0.7153,
        0.7321,
        0.7417,
        0.6834,
        0.6809,
        0.7136,
        0.693,
        0.799,
        0.7099,
        0.713,
        0.6629,
        0.7151,
        0.6783,
        0.7342,
        0.7265,
        0.6635,
        0.7187,
        0.7536,
        0.7108,
        0.6714,
        0.6664,
        0.6849,
        0.7655,
        0.715,
        0.6977,
        0.6581,
        0.7254,
        0.7484,
        0.7495,
        0.7121,
        0.6926,
        0.7385,
        0.6852,
        0.7534,
        0.6925,
        0.693,
        0.7008,
        0.7422,
        0.7369,
        0.7251,
        0.6688,
        0.7008,
        0.7086,
        0.7499,
        0.714,
        0.6598,
        0.6839,
        0.7528,
        0.6966,
        0.6823,
        0.6741,
        0.7301,
        0.6849,
        0.6801,
        0.6978,
        0.7045,
        0.7169,
        0.7022,
        0.7151,
        0.6495,
        0.7012,
        0.6495,
        0.6711,
        0.6328,
        0.7056,
        0.7132,
        0.6827,
        0.6053,
        0.6725,
        0.6957,
        0.6427,
        0.6429,
        0.5967,
        0.6835,
        0.6894,
        0.6547,
        0.6032,
        0.6507,
        0.6483,
        0.6682,
        0.6428,
        0.6406,
        0.592,
        0.659,
        0.7028,
        0.6311,
        0.6656,
        0.6097,
        0.6929,
        0.6125,
        0.7286,
        0.6596,
        0.6077,
        0.6311,
        0.6679,
        0.6742,
        0.6735,
        0.6043,
        0.6806,
        0.6537,
        0.6705,
        0.6872,
        0.6431,
        0.6422,
        0.6652,
        0.6829,
        0.6346,
        0.6018,
        0.6642,
        0.615,
        0.6824,
        0.6876,
        0.6384,
        0.6755,
        0.6957,
        0.6386,
        0.6264,
        0.668,
        0.6976,
        0.6985,
        0.6628,
        0.6726,
        0.5897,
        0.6394,
        0.6693,
        0.6596,
        0.6884,
        0.5967,
        0.6659,
        0.6609,
        0.6627,
        0.6203,
        0.5878,
        0.6926,
        0.6583,
        0.6482,
        0.6399,
        0.6045,
        0.6888,
        0.6823,
        0.6875,
        0.6638,
        0.6232,
        0.6539,
        0.6908,
        0.6612,
        0.6684,
        0.5917,
        0.6398,
        0.6927,
        0.6658,
        0.6469,
        0.6245,
        0.6547,
        0.6738,
        0.6773,
        0.6386,
        0.6142,
        0.6283,
        0.6899,
        0.6318,
        0.6394,
        0.6183,
        0.6262,
        0.6869,
        0.6384,
        0.6482,
        0.6399,
        0.6193,
        0.6551,
        0.7235,
        0.6435,
        0.6442,
        0.7525,
        0.652,
        0.647,
        0.6849,
        0.6408,
        0.7305,
        0.6678,
        0.6752,
        0.6074,
        0.6647,
        0.6876,
        0.6393,
        0.6602,
        0.6236,
        0.6326,
        0.6666,
        0.6481,
        0.5922,
        0.622,
        0.6422,
        0.6694,
        0.6335,
        0.6088,
        0.6967,
        0.6156,
        0.6546,
        0.6196,
        0.631,
        0.6438,
        0.6131,
        0.6886,
        0.6725,
        0.6249,
        0.669,
        0.608,
        0.6764,
        0.648,
        0.7009,
        0.6284,
        0.5715,
        0.6558,
        0.6604,
        0.6535,
        0.6345,
        0.598,
        0.6399,
        0.6468,
        0.6013,
        0.6425,
        0.6382,
        0.686,
        0.6616,
        0.704,
        0.6403,
        0.5649,
        0.6857,
        0.6999,
        0.6479,
        0.6419,
        0.6218,
        0.691,
        0.6876,
        0.6757,
        0.6217,
        0.5572,
        0.7362,
        0.6639,
        0.6607,
        0.6252,
        0.6434,
        0.6434,
        0.5952,
        0.6062,
        0.6104,
        0.5933,
        0.5873,
        0.5627,
        0.5918,
        0.5934,
        0.6291,
        0.5767,
        0.5255,
        0.6127,
        0.5781,
        0.5905,
        0.5633,
        0.5585,
        0.6539,
        0.6334,
        0.6003,
        0.5772,
        0.5347,
        0.6061,
        0.6419,
        0.5479,
        0.5582,
        0.5404,
        0.6531,
        0.6028,
        0.5482,
        0.5579,
        0.5644,
        0.6064,
        0.5913,
        0.6302,
        0.5631,
        0.5461,
        0.6551,
        0.6142,
        0.6295,
        0.5712,
        0.5677,
        0.6012,
        0.5998,
        0.5688,
        0.5585,
        0.5643,
        0.5889,
        0.6405,
        0.5609,
        0.5574,
        0.571,
        0.616,
        0.6381,
        0.5958,
        0.5904,
        0.5562,
        0.5759,
        0.6378,
        0.5804,
        0.5568,
        0.5411,
        0.6559,
        0.6074,
        0.6196,
        0.57,
        0.5601,
        0.6041,
        0.6512,
        0.6167,
        0.5851,
        0.532,
        0.6477,
        0.5868,
        0.5786,
        0.5452,
        0.577,
        0.5936,
        0.6291,
        0.6129,
        0.5574,
        0.5493,
        0.5868,
        0.6191,
        0.5933,
        0.6468,
        0.5067,
        0.6535,
        0.6046,
        0.5802,
        0.5826,
        0.552,
        0.6254,
        0.5682,
        0.545,
        0.5451,
        0.5221,
        0.6329,
        0.5853,
        0.6029,
        0.5443,
        0.5354,
        0.6419,
        0.6439,
        0.5661,
        0.5551,
        0.5512,
        0.6203,
        0.6219,
        0.6153,
        0.5726,
        0.5171,
        0.5946,
        0.6604,
        0.6185,
        0.5895,
        0.5561,
        0.5905,
        0.5777,
        0.6167,
        0.546,
        0.5482,
        0.582,
        0.5743,
        0.6559,
        0.5497,
        0.5518,
        0.5805,
        0.6465,
        0.5864,
        0.5589,
        0.5439,
        0.6347,
        0.6263,
        0.5779,
        0.5725,
        0.5504,
        0.6412,
        0.6184,
        0.6223,
        0.5872,
        0.5937,
        0.6088,
        0.5768,
        0.5967,
        0.6348,
        0.5651,
        0.6327,
        0.6183,
        0.5749,
        0.6044,
        0.5796,
        0.6044,
        0.6142,
        0.6183,
        0.5729,
        0.5009,
        0.5938,
        0.6065,
        0.5894,
        0.5798,
        0.5398,
        0.6161,
        0.6011,
        0.6064,
        0.6147,
        0.5559,
        0.6146,
        0.5655,
        0.5756,
        0.6018,
        0.5448,
        0.6312,
        0.6232,
        0.5807,
        0.5784,
        0.5462,
        0.6209,
        0.5682,
        0.6031,
        0.5688,
        0.5668,
        0.6102,
        0.6193,
        0.5817,
        0.5811,
        0.5007,
        0.6064,
        0.5597,
        0.5679,
        0.5397,
        0.5281,
        0.5098,
        0.5147,
        0.5747,
        0.5386,
        0.5585,
        0.474,
        0.487,
        0.5741,
        0.5509,
        0.5243,
        0.5439,
        0.5177,
        0.5553,
        0.5518,
        0.5512,
        0.5187,
        0.491,
        0.5827,
        0.548,
        0.5553,
        0.491,
        0.434,
        0.5807,
        0.5702,
        0.6053,
        0.4806,
        0.4606,
        0.607,
        0.5538,
        0.519,
        0.5139,
        0.5007,
        0.5968,
        0.5643,
        0.5134,
        0.4787,
        0.4608,
        0.5629,
        0.5295,
        0.5245,
        0.5075,
        0.4814,
        0.5417,
        0.5736,
        0.5569,
        0.4928,
        0.5207,
        0.5686,
        0.5775,
        0.5218,
        0.4851,
        0.507,
        0.546,
        0.5576,
        0.5191,
        0.4948,
        0.5287,
        0.5537,
        0.5625,
        0.5107,
        0.5059,
        0.4703,
        0.6103,
        0.5216,
        0.5344,
        0.4919,
        0.4677,
        0.5908,
        0.5659,
        0.5166,
        0.519,
        0.4767,
        0.5625,
        0.5085,
        0.4887,
        0.4936,
        0.4947,
        0.5443,
        0.5458,
        0.5185,
        0.4895,
        0.4643,
        0.5534,
        0.5632,
        0.5568,
        0.5118,
        0.539,
        0.516,
        0.5417,
        0.5192,
        0.5115,
        0.4897,
        0.5493,
        0.5564,
        0.506,
        0.4873,
        0.5172,
        0.5835,
        0.5571,
        0.5338,
        0.5408,
        0.4995,
        0.5715,
        0.551,
        0.5058,
        0.5434,
        0.506,
        0.5536,
        0.5519,
        0.5712,
        0.4969,
        0.4763,
        0.5485,
        0.5891,
        0.5313,
        0.5408,
        0.4994,
        0.6022,
        0.5665,
        0.5388,
        0.474,
        0.4552,
        0.5447,
        0.5727,
        0.5203,
        0.4823,
        0.5249,
        0.576,
        0.5412,
        0.5365,
        0.493,
        0.5027,
        0.5552,
        0.5302,
        0.5154,
        0.5185,
        0.4982,
        0.5412,
        0.519,
        0.5801,
        0.5254,
        0.4857,
        0.5943,
        0.5629,
        0.5488,
        0.4911,
        0.5192,
        0.5861,
        0.5268,
        0.511,
        0.4939,
        0.5551,
        0.5396,
        0.5397,
        0.4844,
        0.4749,
        0.5745,
        0.5412,
        0.5219,
        0.5113,
        0.4973,
        0.5877,
        0.5216,
        0.5343,
        0.4973,
        0.4757,
        0.5476,
        0.5714,
        0.5668,
        0.5235,
        0.4618,
        0.5758,
        0.5278,
        0.5091,
        0.4877,
        0.46,
        0.571,
        0.5575,
        0.526,
        0.5028,
        0.4955,
        0.5487,
    ]
    # Exported to other cells as the notebook-global name `loss_data`.
    return (loss_data,)


@app.cell
def __():
    # Shared modules for the plotting cells: stdlib first, then third-party.
    import random

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap
    return LinearSegmentedColormap, np, plt, random


@app.cell
def __(plt):
    # Global typography for every figure drawn afterwards:
    # Times New Roman at 14 pt (family is set first, then size).
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 14})
    return


@app.cell
def __(loss_data, np, plt, random):
    # Builds one 1x2 figure:
    #   (a) MAPE against the number of data points used for fitting;
    #   (b) smoothed training-loss curves labelled "PEFT" and "mLoRA".
    # Names prefixed with "_" are cell-local in marimo, so helpers introduced
    # here cannot collide with definitions in other cells.

    # D values; not used in this cell, exported for other cells.
    D_values = [4, 8]

    # One row of two panels; flatten so they are addressed as axs[0] / axs[1].
    fig, axs = plt.subplots(1, 2, figsize=(7, 1), dpi=400)
    axs = axs.flatten()

    def smooth_curve(points, factor=0.8):
        """Return an exponentially smoothed copy of *points*.

        The first output equals the first input; every later output is
        ``previous_output * factor + point * (1 - factor)``, so a larger
        *factor* gives a smoother curve. An empty input yields ``[]``.
        """
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                smoothed_points.append(point)
        return smoothed_points

    # Palette as RGB tuples in [0, 1]. Only c_2 and c_4 are drawn here;
    # c_1 and c_3 are exported for other cells.
    c_1 = (230 / 255, 241 / 255, 243 / 255)
    c_2 = (0, 0, 0)
    c_3 = (255 / 255, 223 / 255, 146 / 255)
    c_4 = (230 / 255, 109 / 255, 104 / 255)

    # ---- Panel (b): training loss ---------------------------------------
    _iterations = range(len(loss_data))

    # The "mLoRA" series is the recorded loss plus uniform noise in
    # [-0.1, 0.1]. A dedicated, fixed-seed generator makes the figure
    # identical on every run and leaves the global `random` state untouched
    # (the unseeded module-level generator produced a different curve each
    # time the cell executed).
    _rng = random.Random(0)
    lorapp_loss = [value + _rng.uniform(-0.1, 0.1) for value in loss_data]

    axs[1].plot(_iterations, smooth_curve(loss_data), label="PEFT", color=c_2)
    axs[1].plot(_iterations, smooth_curve(lorapp_loss), label="mLoRA", color=c_4)
    axs[1].set_xlabel("Training iteration", fontsize=14)
    axs[1].set_ylabel("Loss", fontsize=14)
    axs[1].set_ylim(0.0, 1.35)
    axs[1].set_yticks(
        [0.45, 0.9, 1.35],
        ["0.45", "0.9", "1.35"],
        rotation=90,
        ha="center",
        va="center",
    )
    axs[1].set_xticks([0, 400, 800], ["0", "400", "800"], va="top")
    axs[1].set_xlim(-100, 900)
    axs[1].tick_params(pad=7)
    axs[1].legend(ncol=1, fancybox=False, framealpha=0.0, fontsize=14)

    # ---- Panel (a): MAPE vs. number of fitting points --------------------
    # Measurements are listed from most to fewest points; reverse both lists
    # so the x-axis increases from left to right.
    x = [71, 63, 55, 47, 39, 31, 23, 15, 7, 3]
    y = [
        0.266739094408014,
        0.23996592483352167,
        0.2554257707424939,
        0.22216522633727626,
        0.24286307818113806,
        0.21479247707365348,
        0.202277902928929,
        0.2755166593501294,
        0.712087464890641,
        0.9916578900340629,
    ]
    x = x[::-1]
    y = np.array(y[::-1])

    axs[0].plot(x, y, color=c_4)
    axs[0].set_ylabel("MAPE (%)", fontsize=14)
    axs[0].set_xlabel("Number of data points used for fitting   ", fontsize=14)
    axs[0].set_yticks(
        [2, 1, 0], ["2", "1", "0"], rotation=90, ha="center", va="center"
    )
    axs[0].tick_params(pad=7)

    # Panel labels in each panel's own axes-fraction coordinates. Each label
    # is attached to the axes it annotates (previously "(b)" was added to
    # axs[0] while positioned with axs[1]'s transform).
    for _ax, _panel_label in zip(axs, ("(a)", "(b)")):
        _ax.text(
            0.5,
            1.05,
            _panel_label,
            fontsize=16,
            va="bottom",
            ha="right",
            transform=_ax.transAxes,
            color="black",
        )

    # Uncomment to export the figure:
    # fig.savefig("map-and-loss.pdf", bbox_inches="tight", dpi=1000)

    # marimo displays a cell's last expression as its output.
    fig
    return (
        D_values,
        axs,
        c_1,
        c_2,
        c_3,
        c_4,
        fig,
        lorapp_loss,
        smooth_curve,
        x,
        y,
    )


# Running this file directly as a script starts the marimo app.
if __name__ == "__main__":
    app.run()