{"id":708,"date":"2026-04-06T21:31:22","date_gmt":"2026-04-06T13:31:22","guid":{"rendered":"https:\/\/www.liaoxinghui.com\/?p=708"},"modified":"2026-04-06T21:31:22","modified_gmt":"2026-04-06T13:31:22","slug":"lora-finetune-gradient-vanishing-debug","status":"publish","type":"post","link":"https:\/\/www.liaoxinghui.com\/?p=708","title":{"rendered":"LoRA\u5fae\u8c03\u65f6\u68af\u5ea6\u6d88\u5931\u5bfc\u81f4\u8bad\u7ec3\u65e0\u6548\uff1a\u4ece\u65e5\u5fd7\u5f02\u5e38\u5230optimizer\u72b6\u6001\u5206\u6790"},"content":{"rendered":"<h2>\u4e1a\u52a1\u573a\u666f<\/h2>\n<p>\u6211\u4eec\u7528LoRA\u5fae\u8c03\u4e00\u4e2a7B\u7684\u4e2d\u6587\u5bf9\u8bdd\u6a21\u578b\u3002\u8dd1\u4e8612\u4e2a\u5c0f\u65f6\uff0closs\u4ece3.2\u964d\u52301.8\uff0c\u770b\u8d77\u6765\u5f88\u6b63\u5e38\u5bf9\u5427\uff1f\u4f46\u6211\u7b2c\u4e8c\u5929\u4e00\u770b\u9a8c\u8bc1\u96c6BLEU\u548cROUGE\uff0c\u548c\u6ca1\u8bad\u4e4b\u524d\u51e0\u4e4e\u4e00\u6a21\u4e00\u6837\u3002\u8fd9\u4e0d\u5bf9\u52b2\u3002<\/p>\n<p>\u4e1a\u52a1\u76ee\u6807\u662f\u8ba9\u6a21\u578b\u5728\u7279\u5b9a\u9886\u57df\uff08\u6bd4\u5982\u5ba2\u670d\u5bf9\u8bdd\uff09\u4e0a\u8868\u73b0\u66f4\u597d\u3002\u6211\u4eec\u6709\u4e2a\u5927\u69823000\u6761\u4eba\u5de5\u6807\u6ce8\u7684\u5bf9\u8bdd\u6570\u636e\uff0c\u90fd\u662f\u771f\u5b9e\u7528\u6237query\u548c\u6807\u51c6\u56de\u590d\u3002\u8bad\u7ec3\u73af\u5883\u662f\u4e00\u53f08\u5361A100\u673a\u5668\uff0c\u663e\u5b58\u7ef0\u7ef0\u6709\u4f59\uff0c\u6240\u4ee5\u6392\u9664\u4e86\u8d44\u6e90\u74f6\u9888\u3002<\/p>\n<p>\u6211\u7b2c\u4e00\u53cd\u5e94\u662f\u6570\u636e\u8d28\u91cf\u95ee\u9898\uff0c\u68c0\u67e5\u4e86\u4e24\u904d\u5bf9\u8bdd\u6570\u636e\u7684\u683c\u5f0f\u548c\u957f\u5ea6\u5206\u5e03\uff0c\u6ca1\u53d1\u73b0\u5f02\u5e38\u3002\u7136\u540e\u6000\u7591\u662fepoch\u4e0d\u591f\uff0c\u53c8\u8dd1\u4e866\u5c0f\u65f6\uff0closs\u7ee7\u7eed\u964d\u52301.2\uff0c\u9a8c\u8bc1\u6307\u6807\u8fd8\u662f\u7eb9\u4e1d\u4e0d\u52a8\u3002<\/p>\n<p>\u8fd9\u65f6\u5019\u6211\u624d\u53cd\u5e94\u8fc7\u6765\uff0
c\u95ee\u9898\u53ef\u80fd\u4e0d\u5728\u6570\u636e\uff0c\u800c\u5728\u8bad\u7ec3\u8fc7\u7a0b\u672c\u8eab\u3002<\/p>\n<h2>\u95ee\u9898\u5b9a\u4f4d<\/h2>\n<p>\u5148\u8bf4\u7ed3\u8bba\uff1a\u6700\u540e\u5b9a\u4f4d\u5230\u662f\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002<\/p>\n<p>\u5f53\u65f6\u6211\u6ca1\u6709\u5728\u8bad\u7ec3\u5faa\u73af\u91cc\u52a0\u68af\u5ea6\u7edf\u8ba1\uff0c\u53ea\u662f\u770bloss\u4e0b\u964d\u5c31\u4ee5\u4e3a\u4e00\u5207\u6b63\u5e38\u3002\u56de\u8fc7\u5934\u8865\u4e0a\u68af\u5ea6\u76d1\u63a7\u4e4b\u540e\uff0c\u53d1\u73b0\u95ee\u9898\u4e86\u2014\u2014<\/p>\n<pre><code class=\"lang-python language-python python\">import torch\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import LoraConfig, get_peft_model, TaskType\n\n# \u8bad\u7ec3\u5faa\u73af\u91cc\u52a0\u8fd9\u4e2a\u56de\u8c03\ndef gradient_stats_hook(module, grad_input, grad_output):\n    &quot;&quot;&quot;\u6bcf\u5c42\u68af\u5ea6\u8303\u6570\u7edf\u8ba1&quot;&quot;&quot;\n    if grad_output[0] is not None:\n        grad_norm = grad_output[0].norm().item()\n        module._grad_norm = grad_norm\n\n# \u6ce8\u518chook\u5230\u6240\u6709LoRA\u5c42\nmodel = AutoModelForCausalLM.from_pretrained(\n    &quot;meta-llama\/Llama-2-7b-hf&quot;,\n    device_map=&quot;auto&quot;,\n    torch_dtype=torch.float16\n)\n\nlora_config = LoraConfig(\n    task_type=TaskType.CAUSAL_LM,\n    r=16,\n    lora_alpha=32,\n    lora_dropout=0.05,\n    target_modules=[&quot;q_proj&quot;, &quot;v_proj&quot;, &quot;k_proj&quot;, &quot;o_proj&quot;]\n)\n\nmodel = get_peft_model(model, lora_config)\nmodel.print_trainable_parameters()\n\n# \u7ed9\u6240\u6709\u53ef\u8bad\u7ec3\u53c2\u6570\u6302hook\nfor name, param in model.named_parameters():\n    if param.requires_grad:\n        param.register_hook(lambda grad, n=name: \n            print(f&quot;[{n}] grad_norm={grad.norm().item():.6f}&quot;)\n        
)<\/code><\/pre>\n<p>\u8dd1\u4e86\u4e00\u4e2astep\uff0c\u8f93\u51fa\u628a\u6211\u5413\u5230\u4e86\uff1a<\/p>\n<pre><code>[lora_B.weight] grad_norm=0.0234\n[lora_A.weight] grad_norm=0.0001\n[model.layers.0.self_attn.q_proj.lora_B.weight] grad_norm=0.0218\n[model.layers.0.self_attn.q_proj.lora_A.weight] grad_norm=0.00008\n...\n[model.layers.30.self_attn.q_proj.lora_B.weight] grad_norm=0.0192\n[model.layers.30.self_attn.q_proj.lora_A.weight] grad_norm=0.00002<\/code><\/pre>\n<p>\u770b\u5230\u4e86\u5417\uff1f\u6240\u6709\u5c42\u7684<code>lora_B.weight<\/code>\u68af\u5ea6\u8303\u6570\u57280.02\u5de6\u53f3\uff0c\u4f46<code>lora_A.weight<\/code>\u68af\u5ea6\u53ea\u67090.0001\u5de6\u53f3\uff0c\u5dee\u4e86200\u500d\u3002<\/p>\n<h2>\u6570\u636e\u8bf4\u660e<\/h2>\n<h3>\u8bad\u7ec3\u6570\u636e\u6765\u6e90<\/h3>\n<p>\u6211\u4eec\u7684\u8bad\u7ec3\u6570\u636e\u662f\u5185\u90e8\u79ef\u7d2f\u7684\u5ba2\u670d\u5bf9\u8bdd\uff0c\u7528<code>transformers<\/code>\u7684\u5bf9\u8bdd\u6a21\u677f\u5904\u7406\u8fc7\u3002\u5927\u6982\u957f\u8fd9\u6837\uff1a<\/p>\n<pre><code class=\"lang-json language-json json\">{\n  &quot;conversations&quot;: [\n    {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;\u6211\u7684\u8ba2\u5355\u53f7\u662f20240115\uff0c\u4ec0\u4e48\u65f6\u5019\u53d1\u8d27\uff1f&quot;},\n    {&quot;role&quot;: &quot;assistant&quot;, &quot;content&quot;: &quot;\u60a8\u597d\uff01\u67e5\u8be2\u5230\u60a8\u7684\u8ba2\u5355\uff0c\u9884\u8ba148\u5c0f\u65f6\u5185\u53d1\u8d27\u3002&quot;}\n  ]\n}<\/code><\/pre>\n<p>\u6570\u636e\u91cf3000\u6761\uff0c\u5e73\u5747\u6bcf\u6761\u5bf9\u8bdd\u957f\u5ea6\u7ea6512 tokens\u3002\u7528\u7684\u662fChatML\u6a21\u677f\uff0c\u5728user\u548cassistant\u6d88\u606f\u524d\u540e\u52a0\u4e86\u7279\u6b8a\u6807\u8bb0\u3002<\/p>\n<h3>\u6570\u636e\u9884\u5904\u7406<\/h3>\n<p>\u6570\u636e\u52a0\u8f7d\u65f6\u505a\u4e86\u8fd9\u51e0\u4ef6\u4e8b\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\">from transformers import 
AutoTokenizer\n\ntokenizer = AutoTokenizer.from_pretrained(&quot;meta-llama\/Llama-2-7b-hf&quot;, trust_remote_code=True)\n\ndef preprocess_function(examples):\n    # \u62fc\u63a5\u5bf9\u8bdd\u5386\u53f2\uff0c\u52a0\u4e0a\u7279\u6b8atoken\n    texts = []\n    for conv in examples[&quot;conversations&quot;]:\n        text = &quot;&lt;|im_start|&gt;user\\n&quot; + conv[0][&quot;content&quot;]\n        text += &quot;&lt;|im_end|&gt;&lt;|im_start|&gt;assistant\\n&quot; + conv[1][&quot;content&quot;]\n        text += &quot;&lt;|im_end|&gt;&quot;\n        texts.append(text)\n\n    # tokenize\uff0clabels\u62f7\u8d1dinput_ids\n    model_inputs = tokenizer(texts, max_length=512, truncation=True, padding=&quot;max_length&quot;)\n    model_inputs[&quot;labels&quot;] = model_inputs[&quot;input_ids&quot;].copy()\n    return model_inputs<\/code><\/pre>\n<p>\u9884\u5904\u7406\u8fd9\u5757\u6211\u6ca1\u53d1\u73b0\u95ee\u9898\uff0c\u540e\u6765\u5b9a\u4f4d\u5230\u95ee\u9898\u5728\u6a21\u578b\u672c\u8eab\u7684\u8bad\u7ec3\u903b\u8f91\u3002<\/p>\n<h3>\u6570\u636e\u5206\u5e03\u68c0\u67e5\uff08\u5bb9\u6613\u8e29\u7684\u5751\uff09<\/h3>\n<p>\u6211\u540e\u6765\u53c8\u9047\u5230\u8fc7\u4e00\u6b21\u7c7b\u4f3c\u95ee\u9898\uff0c\u8fd9\u6b21\u4e0d\u662f\u521d\u59cb\u5316\u95ee\u9898\uff0c\u800c\u662f\u6570\u636e\u95ee\u9898\u3002<\/p>\n<p>\u6570\u636e\u96c6\u91cc\u67d0\u4e2a\u7279\u6b8atoken\u5360\u6bd4\u8d85\u8fc740%\uff0c\u5bfc\u81f4\u6a21\u578b\u53ea\u8981\u9884\u6d4b\u8fd9\u4e2atoken\u5c31\u80fd\u628aloss\u62c9\u4f4e\uff0c\u4f46\u5b9e\u9645\u4e0a\u4ec0\u4e48\u90fd\u6ca1\u5b66\u5230\u3002\u8fd9\u79cd\u60c5\u51b5\u4e0b\u68af\u5ea6\u6d88\u5931\u662f\u8868\u8c61\uff0c\u771f\u6b63\u7684\u95ee\u9898\u662f\u6570\u636e\u4e0d\u5e73\u8861\u3002<\/p>\n<p>\u6392\u67e5\u65b9\u6cd5\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\">from collections import Counter\nimport torch\nfrom transformers import AutoTokenizer\n\ntokenizer = 
AutoTokenizer.from_pretrained(&quot;meta-llama\/Llama-2-7b-hf&quot;)\n\n# \u7edf\u8ba1token\u5206\u5e03\nall_token_ids = []\nfor sample in dataset:\n    tokens = tokenizer.encode(sample[&#039;text&#039;])\n    all_token_ids.extend(tokens)\n\ntoken_counts = Counter(all_token_ids)\ntotal = sum(token_counts.values())\n\n# \u6253\u5370\u6700\u5e38\u89c1\u768420\u4e2atoken\nprint(&quot;Top 20 tokens:&quot;)\nfor token_id, count in token_counts.most_common(20):\n    token_str = tokenizer.decode([token_id])\n    pct = count \/ total * 100\n    print(f&quot;  [{token_id:5d}] &#039;{token_str}&#039; -&gt; {pct:.2f}%&quot;)\n\n# \u68c0\u67e5\u7279\u6b8atoken\u6bd4\u4f8b\nspecial_tokens = tokenizer.all_special_ids\nspecial_ratio = sum(token_counts[t] for t in special_tokens) \/ total\nprint(f&quot;\\nSpecial tokens ratio: {special_ratio:.2%}&quot;)\n\n# \u5982\u679c\u8d85\u8fc75%\uff0c\u8bf4\u660e\u6709\u95ee\u9898\nif special_ratio &gt; 0.05:\n    print(&quot;\u26a0\ufe0f WARNING: Special token ratio too high, may cause gradient issues&quot;)<\/code><\/pre>\n<h2>\u6839\u56e0\u5206\u6790<\/h2>\n<h3>LoRA\u68af\u5ea6\u6d88\u5931\u7684\u539f\u7406<\/h3>\n<p>\u8fd9\u91cc\u5f97\u89e3\u91ca\u4e0bLoRA\u7684\u7ed3\u6784\uff0c\u4e0d\u7136\u8bf4\u4e0d\u6e05\u695a\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u4e2a\u95ee\u9898\u3002<\/p>\n<p>LoRA\u7684\u6838\u5fc3\u662f\u4f4e\u79e9\u5206\u89e3\u3002\u5bf9\u4e8e\u4e00\u4e2a\u539f\u59cb\u6743\u91cd $W \\in \\mathbb{R}^{d \\times k}$\uff0cLoRA\u5f15\u5165\u4e24\u4e2a\u77e9\u9635\uff1a<\/p>\n<ul>\n<li>$A \\in \\mathbb{R}^{r \\times k}$\uff0c\u7528\u968f\u673a\u521d\u59cb\u5316<\/li>\n<li>$B \\in \\mathbb{R}^{d \\times r}$\uff0c\u521d\u59cb\u5316\u4e3a\u96f6<\/li>\n<\/ul>\n<p>\u524d\u5411\u8ba1\u7b97\u53d8\u6210\uff1a<\/p>\n<p>$$h = W \\cdot x + B \\cdot A \\cdot x$$<\/p>\n<p>\u95ee\u9898\u51fa\u5728\u521d\u59cb\u5316\u4e0a\u3002$B$\u521d\u59cb\u5316\u4e3a\u96f6\uff0c\u610f\u5473\u7740\u8bad\u7ec3\u5f00\u59cb\u65f6 $B \\cdot A = 
0$\uff0c\u65b0\u589e\u7684\u8def\u5f84\u5bf9\u8f93\u51fa\u6ca1\u6709\u4efb\u4f55\u8d21\u732e\u3002\u7136\u540e\u68af\u5ea6\u4ece $h$ \u56de\u4f20\u5230 $B$ \u548c $A$\uff0c\u4f46 $A$ \u7684\u68af\u5ea6\u8981\u7ecf\u8fc7 $B$ \u624d\u80fd\u5f71\u54cd\u8f93\u51fa\u2014\u2014\u8fd9\u5f62\u6210\u4e00\u4e2a\u94fe\u5f0f\u4f9d\u8d56\u3002<\/p>\n<p>\u5177\u4f53\u6765\u8bf4\uff0c$B$ \u7684\u68af\u5ea6\u662f\uff1a<\/p>\n<p>$$\\frac{\\partial L}{\\partial B} = \\frac{\\partial L}{\\partial h} \\cdot (A \\cdot x)^T$$<\/p>\n<p>\u800c $A$ \u7684\u68af\u5ea6\u662f\uff1a<\/p>\n<p>$$\\frac{\\partial L}{\\partial A} = B^T \\cdot \\frac{\\partial L}{\\partial h} \\cdot x^T$$<\/p>\n<p>\u8bad\u7ec3\u521a\u5f00\u59cb\u65f6 $B = 0$\uff0c\u6240\u4ee5 $\\frac{\\partial L}{\\partial A} \\approx 0$\uff0c$A$ \u51e0\u4e4e\u6536\u4e0d\u5230\u68af\u5ea6\u3002\u800c $B$ \u7684\u68af\u5ea6\u867d\u7136\u6709\uff0c\u4f46\u91cf\u7ea7\u5f88\u5c0f\uff0c\u56e0\u4e3a $A \\cdot x$ \u7684\u521d\u59cb\u8f93\u51fa\u4e5f\u5728\u96f6\u70b9\u9644\u8fd1\u3002<\/p>\n<p>\u8fd9\u5c31\u662f\u6240\u8c13\u7684&#8221;\u68af\u5ea6\u6d88\u5931&#8221;\u2014\u2014\u4e0d\u662f\u771f\u7684\u6d88\u5931\uff0c\u800c\u662f\u521d\u59cb\u5316\u5bfc\u81f4 $A$ \u88ab\u9501\u6b7b\u4e86\uff0c\u53ea\u6709 $B$ \u5728\u5fae\u5f31\u5730\u66f4\u65b0\u3002<\/p>\n<h3>\u7528optimizer\u72b6\u6001\u9a8c\u8bc1<\/h3>\n<p>\u68af\u5ea6\u76d1\u63a7\u786e\u8ba4\u4e86\u95ee\u9898\uff0c\u4f46\u6211\u8fd8\u60f3\u770b\u770boptimizer\u5b9e\u9645\u5728\u505a\u4ec0\u4e48\u3002<\/p>\n<pre><code class=\"lang-python language-python python\"># \u8bad\u7ec3\u51e0\u4e2astep\u540e\u68c0\u67e5optimizer\u72b6\u6001\noptimizer = torch.optim.AdamW(\n    model.parameters(),\n    lr=2e-4,\n    betas=(0.9, 0.999),\n    weight_decay=0.01\n)\n\n# \u6a21\u62df\u51e0\u6b65\u8bad\u7ec3\nfor batch in dataloader:\n    outputs = model(**batch)\n    loss = outputs.loss\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    break  # 
\u770b\u7b2c\u4e00\u6b65\u4e4b\u540e\u7684\u72b6\u6001\n\n# \u68c0\u67e5exp_avg\u548cexp_avg_sq\nprint(&quot;=== Optimizer State Analysis ===&quot;)\nfor name, param in model.named_parameters():\n    if not param.requires_grad:\n        continue\n\n    state = optimizer.state.get(param)\n    if state is None:\n        continue\n\n    exp_avg = state[&#039;exp_avg&#039;]\n    exp_avg_sq = state[&#039;exp_avg_sq&#039;]\n\n    print(f&quot;\\n{name}:&quot;)\n    print(f&quot;  param mean: {param.data.mean():.6f}, std: {param.data.std():.6f}&quot;)\n    print(f&quot;  exp_avg mean: {exp_avg.mean():.8f}, std: {exp_avg.std():.8f}&quot;)\n    print(f&quot;  exp_avg_sq mean: {exp_avg_sq.mean():.8f}&quot;)\n    print(f&quot;  update\/param ratio: {(exp_avg \/ (exp_avg_sq.sqrt() + 1e-8)).mean():.8f}&quot;)<\/code><\/pre>\n<p>\u8f93\u51fa\u5927\u6982\u662f\u8fd9\u6837\uff1a<\/p>\n<pre><code>=== Optimizer State Analysis ===\n\nlora_B.weight:\n  param mean: 0.002341, std: 0.015623\n  exp_avg mean: 0.000023, std: 0.000156\n  exp_avg_sq mean: 0.000061\n  update\/param ratio: 0.029412\n\nlora_A.weight:\n  param mean: 0.031256, std: 0.100523\n  exp_avg mean: 0.000001, std: 0.000005\n  exp_avg_sq mean: 0.000000\n  update\/param ratio: 0.000823<\/code><\/pre>\n<p>lora_A\u7684exp_avg_sq\u51e0\u4e4e\u662f\u96f6\uff0c\u8bf4\u660e\u5b83\u51e0\u4e4e\u6ca1\u6536\u5230\u8fc7\u6709\u6548\u7684\u68af\u5ea6\u3002\u800clora_B\u867d\u7136\u72b6\u6001\u4e5f\u4e0d\u6d3b\u8dc3\uff0c\u4f46\u6bd4A\u597d\u592a\u591a\u4e86\u3002<\/p>\n<h2>\u53c2\u6570\u8bf4\u660e<\/h2>\n<h3>rank\u548calpha\u7684\u53d6\u503c\u4f9d\u636e<\/h3>\n<p>\u6211\u7528\u7684\u662f<code>r=16, 
lora_alpha=32<\/code>\uff0c\u8fd9\u4e24\u4e2a\u6570\u662f\u600e\u4e48\u5b9a\u7684\uff1f<\/p>\n<p><strong>rank\u7684\u9009\u62e9<\/strong>\uff1a<\/p>\n<p>rank\u51b3\u5b9a\u4e86\u4f4e\u79e9\u77e9\u9635\u7684\u79e9\uff0c\u4e5f\u5373LoRA\u5c42\u80fd\u8868\u8fbe\u7684\u7ebf\u6027\u7a7a\u95f4\u7ef4\u5ea6\u3002\u6211\u5f53\u65f6\u8003\u8651\u7684\u662f\uff1a<\/p>\n<ul>\n<li>\n<p>rank=8\uff1a\u592a\u4fdd\u5b88\u4e86\u30027B\u6a21\u578b\u67094096\u7684hidden_size\uff0crank=8\u610f\u5473\u7740\u6bcf\u4e2aLoRA\u5c42\u53ea\u670964\u7ef4\u7684\u8868\u793a\u80fd\u529b\u3002\u5bf9\u4e8e\u5ba2\u670d\u5bf9\u8bdd\u8fd9\u79cd\u9700\u8981\u6355\u6349\u8bed\u4e49\u7ec6\u5fae\u5dee\u522b\u7684\u4efb\u52a1\uff0c8\u53ef\u80fd\u4e0d\u591f\u3002<\/p>\n<\/li>\n<li>\n<p>rank=32\uff1a\u7406\u8bba\u4e0a\u8868\u793a\u80fd\u529b\u66f4\u5f3a\uff0c\u4f46\u663e\u5b58\u5360\u7528\u4e5f\u4f1a\u7ffb\u500d\u30023000\u6761\u6570\u636e\u7684\u5c0f\u4efb\u52a1\uff0c\u752832\u6709\u70b9 overkill\u3002<\/p>\n<\/li>\n<li>\n<p>rank=16\uff1a\u6298\u4e2d\u65b9\u6848\u3002\u5b9e\u6d4b\u4e0b\u6765\uff0crank=16\u5728\u5927\u591a\u6570\u5782\u57df\u5fae\u8c03\u573a\u666f\u591f\u7528\u4e86\u3002\u9664\u975e\u4f60\u7684\u4efb\u52a1\u9700\u8981\u975e\u5e38\u7ec6\u7c92\u5ea6\u7684\u6307\u4ee4\u9075\u5faa\uff08\u6bd4\u5982\u590d\u6742\u7684\u6570\u5b66\u63a8\u7406\uff09\uff0c\u90a3\u65f6\u5019\u518d\u8003\u865132\u621664\u3002<\/p>\n<\/li>\n<\/ul>\n<p><strong>alpha\u7684\u9009\u62e9<\/strong>\uff1a<\/p>\n<p>alpha\u662f\u7f29\u653e\u56e0\u5b50\uff0c\u6700\u7ec8scale = alpha \/ rank\u3002alpha=32\u610f\u5473\u7740scale=2\uff0c\u4e5f\u5c31\u662fLoRA\u5206\u652f\u7684\u8f93\u51fa\u6743\u91cd\u662f2\u500d\u3002<\/p>\n<p>\u516c\u5f0f\u4e0a\uff0c\u6700\u7ec8\u8f93\u51fa\u662f <code>h = Wx + (BAx) * 
(alpha\/r)<\/code>\u3002\u5982\u679calpha\u592a\u5c0f\uff0cLoRA\u7684\u8d21\u732e\u88ab\u538b\u5f97\u5f88\u4f4e\uff1b\u5982\u679calpha\u592a\u5927\uff0c\u53ef\u80fd\u7834\u574f\u9884\u8bad\u7ec3\u5b66\u5230\u7684\u77e5\u8bc6\u3002<\/p>\n<p>\u6211\u4e00\u822calpha\u8bbe\u4e3arank\u76842\u500d\uff0c\u4e5f\u5c31\u662f\u56fa\u5b9ascale=2\u3002\u5982\u679c\u8bad\u7ec3\u4e0d\u7a33\u5b9a\uff08loss\u7206\u70b8\uff09\uff0c\u4f1a\u964d\u4f4ealpha\u52301.5\u500d\u62161\u500d\u3002<\/p>\n<h3>target_modules\u7684\u9009\u62e9<\/h3>\n<pre><code class=\"lang-python language-python python\">target_modules=[&quot;q_proj&quot;, &quot;v_proj&quot;, &quot;k_proj&quot;, &quot;o_proj&quot;]<\/code><\/pre>\n<p>q\u662fquery\uff0ck\u662fkey\uff0cv\u662fvalue\uff0co\u662fattention\u8f93\u51fa\u3002\u6211\u5f53\u65f6\u5168\u9009\u4e86\u3002<\/p>\n<p>\u5176\u5b9e\u5982\u679c\u663e\u5b58\u7d27\u5f20\uff0c\u53ef\u4ee5\u53ea\u6253q\u548cv\u3002o\u4e00\u822c\u5f71\u54cd\u4e0d\u5927\uff0c\u56e0\u4e3a\u5b83\u662f\u8f93\u51fa\u6295\u5f71\uff0c\u5df2\u7ecf\u88ab\u4e0b\u6e38\u5c42\u5904\u7406\u8fc7\u4e86\u3002<\/p>\n<p>\u8fd8\u6709\u4e2a\u9009\u62e9\u662f\u52a0\u4e0affn\u5c42\u7684gate\u548cup_proj\uff0c\u4f46\u6211\u6ca1\u52a0\uff0c\u6015\u5f15\u5165\u592a\u591a\u53c2\u6570\u5bfc\u81f4\u8fc7\u62df\u5408\u3002<\/p>\n<h2>\u89e3\u51b3\u65b9\u6848<\/h2>\n<p>\u627e\u5230\u6839\u56e0\u5c31\u597d\u529e\u4e86\u3002\u6211\u8bd5\u4e86\u56db\u79cd\u65b9\u6848\uff0c\u6700\u540e\u9009\u4e86\u65b9\u68484\u3002<\/p>\n<h3>\u65b9\u68481\uff1a\u6539A\u7684\u521d\u59cb\u5316<\/h3>\n<p>\u6807\u51c6LoRA\u7528\u968f\u673a\u521d\u59cb\u5316A\u3001\u9ad8\u65af\u521d\u59cb\u5316B\u4e3a\u96f6\u3002\u66f4\u5408\u7406\u7684\u505a\u6cd5\u662f\u8ba9A\u7684\u521d\u59cb\u5316\u65b9\u5dee\u5c0f\u4e00\u4e9b\uff0c\u8fd9\u6837 $A \\cdot x$ \u7684\u8f93\u51fa\u4e0d\u4f1a\u4e00\u5f00\u59cb\u5c31\u628aB\u7684\u8d21\u732e\u6df9\u6ca1\u3002<\/p>\n<pre><code class=\"lang-python language-python python\">from peft.tuners.lora import 
LoraLayer\n\n# \u81ea\u5b9a\u4e49LoRA\u5c42\u521d\u59cb\u5316\nclass CustomLoraLayer(LoraLayer):\n    def reset_lora_parameters(self):\n        # B\u4fdd\u6301\u4e3a\u96f6\u521d\u59cb\u5316\n        if hasattr(self, &#039;lora_B&#039;) and self.lora_B is not None:\n            self.lora_B.zero_()\n        # A\u6539\u7528\u66f4\u5c0f\u7684\u65b9\u5dee\n        if hasattr(self, &#039;lora_A&#039;) and self.lora_A is not None:\n            torch.nn.init.normal_(self.lora_A.weight, std=0.01)<\/code><\/pre>\n<p>\u4f46\u5b9e\u6d4b\u4e0b\u6765\u8fd9\u4e2a\u6539\u52a8\u6548\u679c\u6709\u9650\u3002\u95ee\u9898\u5728\u4e8e\uff0c\u53ea\u6539A\u7684\u65b9\u5dee\u5e76\u4e0d\u80fd\u89e3\u51b3\u94fe\u5f0f\u4f9d\u8d56\u2014\u2014A\u7684\u68af\u5ea6\u8fd8\u662f\u8981\u7ecf\u8fc7B\u624d\u80fd\u4f20\u5bfc\u56de\u53bb\uff0c\u65b9\u5dee\u5c0f\u53ea\u662f\u8ba9\u68af\u5ea6\u91cf\u7ea7\u5c0f\u4e00\u70b9\uff0c\u5e76\u6ca1\u6709\u6253\u7834\u6b7b\u9501\u3002<\/p>\n<h3>\u65b9\u68482\uff1a\u589e\u5927alpha\/r\u6bd4\u4f8b<\/h3>\n<p>LoRA\u6709\u4e2ascale\u53c2\u6570 $\\alpha \/ r$\u3002\u5982\u679c\u5f53\u524d r=16, alpha=32\uff0cscale=2\u3002\u6211\u8bd5\u8fc7\u628aalpha\u8c03\u523064\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\">lora_config = LoraConfig(\n    r=16,\n    lora_alpha=64,  # \u539f\u6765\u662f32\n    ...\n)<\/code><\/pre>\n<p>\u8fd9\u8ba9\u6700\u7ec8\u8f93\u51fa\u91cc $B \\cdot A \\cdot x$ 
\u7684\u6743\u91cd\u53d8\u5927\u4e86\uff0c\u4f46\u6cbb\u6807\u4e0d\u6cbb\u672c\u2014\u2014\u68af\u5ea6\u4f9d\u7136\u53ea\u6d41\u5411B\uff0cA\u8fd8\u662f\u534a\u6b7b\u4e0d\u6d3b\u3002\u800c\u4e14alpha\u592a\u5927\u8fd8\u6709\u4e2a\u95ee\u9898\uff1a\u53ef\u80fd\u7834\u574f\u9884\u8bad\u7ec3\u6743\u91cd\u5df2\u7ecf\u5b66\u597d\u7684\u8868\u793a\u3002<\/p>\n<h3>\u65b9\u68483\uff1a\u5f52\u4e00\u5316\u68af\u5ea6<\/h3>\n<p>\u6700\u7ec8\u65b9\u6848\u662f\u628a\u6240\u6709LoRA\u53c2\u6570\u7684\u68af\u5ea6\u505a\u5f52\u4e00\u5316\uff0c\u8ba9A\u548cB\u6536\u5230\u76f8\u5bf9\u5747\u8861\u7684\u68af\u5ea6\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\">class GradientNormalizedLoRA(torch.nn.Module):\n    def __init__(self, base_layer, rank=16, alpha=32):\n        super().__init__()\n        self.base_layer = base_layer\n        self.rank = rank\n        self.scale = alpha \/ rank\n\n        # A\u548cB\u5206\u522b\u5b58\uff0c\u65b9\u4fbf\u5355\u72ec\u5904\u7406\n        self.lora_A = torch.nn.Parameter(torch.randn(rank, base_layer.in_features) * 0.01)\n        self.lora_B = torch.nn.Parameter(torch.zeros(base_layer.out_features, rank))\n\n        self.base_layer.weight.requires_grad = False\n\n    def forward(self, x):\n        # \u539f\u59cb\u524d\u5411\n        base_out = self.base_layer(x)\n        # LoRA\u8def\u5f84\n        lora_out = (self.lora_B @ (self.lora_A @ x.T)).T\n        return base_out + lora_out * self.scale\n\ndef gradient_balance_hook(model):\n    &quot;&quot;&quot;\u628a\u6240\u6709LoRA\u5bf9\u7684\u68af\u5ea6\u505a\u5747\u8861&quot;&quot;&quot;\n    # \u6309\u5c42\u904d\u5386\uff0c\u627e\u5230lora_A\u548clora_B\u914d\u5bf9\n    for name, module in model.named_modules():\n        if isinstance(module, GradientNormalizedLoRA):\n            # \u68af\u5ea6\u88c1\u526a\u548c\u5747\u8861\n            grad_A = module.lora_A.grad\n            grad_B = module.lora_B.grad\n\n            if grad_A is not None and grad_B is not None:\n                
# \u8ba1\u7b97\u68af\u5ea6\u8303\u6570\u6bd4\n                norm_A = grad_A.norm()\n                norm_B = grad_B.norm()\n\n                if norm_A &gt; 1e-8 and norm_B &gt; 1e-8:\n                    # \u628aB\u7684\u68af\u5ea6\u6309\u6bd4\u4f8b\u653e\u5927\uff0c\u8ba9\u5b83\u4fe9\u91cf\u7ea7\u63a5\u8fd1\n                    ratio = norm_B \/ norm_A\n                    if ratio &lt; 0.1:  # A\u8fdc\u5927\u4e8eB\n                        module.lora_B.grad = grad_B * (norm_A \/ norm_B) * 0.5\n                    elif ratio &gt; 10:  # B\u8fdc\u5927\u4e8eA\n                        module.lora_A.grad = grad_A * (norm_B \/ norm_A) * 0.5<\/code><\/pre>\n<p>\u4f46\u8fd9\u4e2a\u65b9\u6848\u6709\u4e2a\u95ee\u9898\u2014\u2014\u5728\u8bad\u7ec3\u5faa\u73af\u91cc\u624b\u52a8\u6539\u68af\u5ea6\u4f1a\u589e\u52a0\u7ea615%\u7684\u8ba1\u7b97\u5f00\u9500\uff0c\u800c\u4e14\u903b\u8f91\u6bd4\u8f83hacky\uff0c\u4e0d\u9002\u5408\u751f\u4ea7\u73af\u5883\u3002<\/p>\n<h3>\u65b9\u68484\uff1a\u6b63\u4ea4\u521d\u59cb\u5316\uff08\u6211\u7684\u9009\u62e9\uff09<\/h3>\n<p>\u7efc\u5408\u8003\u8651\u5b9e\u73b0\u6210\u672c\u548c\u6548\u679c\uff0c\u6211\u6700\u540e\u9009\u4e86\u65b9\u68484\uff0c\u539f\u56e0\u662f\u5b83\u4e0d\u9700\u8981\u6539\u8bad\u7ec3\u5faa\u73af\uff0c\u7eaf\u9760\u6539\u521d\u59cb\u5316\u5c31\u80fd\u89e3\u51b3\u95ee\u9898\uff0c\u4ee3\u7801\u5e72\u51c0\u5229\u843d\u3002<\/p>\n<p>\u540e\u6765\u6211\u53d1\u73b0peft\u5e93\u5728\u8f83\u65b0\u7248\u672c\u91cc\u5df2\u7ecf\u652f\u6301\u81ea\u5b9a\u4e49\u521d\u59cb\u5316\u4e86\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\">from peft import LoraConfig, get_peft_model\nfrom peft.tuners.lora import LoraConfig as PEFTLoraConfig\n\nlora_config = LoraConfig(\n    r=16,\n    lora_alpha=32,\n    lora_dropout=0.05,\n    target_modules=[&quot;q_proj&quot;, &quot;v_proj&quot;],\n    init_lora_weights=&quot;gaussian&quot;  # 
\u65b0\u7248\u652f\u6301\n)<\/code><\/pre>\n<p>\u5982\u679c\u4f60\u7528\u7684peft\u7248\u672c\u4e0d\u652f\u6301\uff0c\u4e5f\u53ef\u4ee5\u76f4\u63a5\u6539\u6e90\u7801\u3002\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u76f4\u63a5fork\u4e00\u4efdpeft\u6309\u9700\u6539\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\"># \u5728 peft\/tuners\/lora.py \u91cc\u7684 LoraLayer.reset_lora_parameters\n# \u539f\u6765\u662f\u8fd9\u6837\uff1a\ndef reset_lora_parameters(self):\n    if self.lora_A is not None:\n        # \u968f\u673a\u521d\u59cb\u5316\n        self.lora_A.weight.normal_(mean=0, std=0.02)\n    if self.lora_B is not None:\n        # \u96f6\u521d\u59cb\u5316\n        self.lora_B.weight.zero_()\n\n# \u6539\u6210\u8fd9\u6837\uff08\u7528\u6b63\u4ea4\u521d\u59cb\u5316\uff0c\u800c\u4e0d\u662f\u5168\u96f6\uff09\uff1a\ndef reset_lora_parameters(self):\n    if self.lora_A is not None:\n        self.lora_A.weight.normal_(mean=0, std=0.005)\n    if self.lora_B is not None:\n        # \u6539\u6210\u5c0f\u5e45\u5ea6\u7684\u6b63\u4ea4\u521d\u59cb\u5316\uff0c\u800c\u4e0d\u662f\u5168\u96f6\n        torch.nn.init.orthogonal_(self.lora_B.weight, gain=0.01)<\/code><\/pre>\n<p><strong>\u4e3a\u4ec0\u4e48\u6b63\u4ea4\u521d\u59cb\u5316\u6bd4\u65b9\u68481\uff08\u6539A\u65b9\u5dee\uff09\u548c\u65b9\u68483\uff08\u68af\u5ea6\u5747\u8861\uff09\u66f4\u597d\uff1f<\/strong><\/p>\n<ol>\n<li>\n<p>\u76f8\u6bd4\u65b9\u68481\uff08\u6539A\u65b9\u5dee\uff09\uff1a\u53ea\u6539A\u7684\u65b9\u5dee\u6ca1\u6709\u6253\u7834\u68af\u5ea6\u94fe\u5f0f\u4f9d\u8d56\uff0c\u53ea\u662f\u8ba9\u68af\u5ea6\u91cf\u7ea7\u5c0f\u4e00\u70b9\u3002\u6b63\u4ea4\u521d\u59cb\u5316\u662f\u540c\u65f6\u6539A\u548cB\uff0c\u8ba9B\u4e0d\u518d\u662f\u96f6\u77e9\u9635\uff0c\u8fd9\u6837\u524d\u5411\u65f6 $BA$ 
\u4e00\u5f00\u59cb\u5c31\u6709\u975e\u96f6\u8f93\u51fa\uff0c\u53cd\u5411\u65f6A\u548cB\u7684\u68af\u5ea6\u90fd\u80fd\u6b63\u5e38\u4f20\u5bfc\u3002<\/p>\n<\/li>\n<li>\n<p>\u76f8\u6bd4\u65b9\u68483\uff08\u68af\u5ea6\u5747\u8861\uff09\uff1a\u68af\u5ea6\u5747\u8861\u9700\u8981\u6bcf\u6b21backward\u540e\u624b\u52a8\u4fee\u6539\u68af\u5ea6\uff0c\u589e\u52a0\u4e86\u7ea615%\u7684\u8ba1\u7b97\u5f00\u9500\uff0c\u800c\u4e14\u4ee3\u7801\u4fb5\u5165\u6027\u5f3a\u3002\u6b63\u4ea4\u521d\u59cb\u5316\u662f\u4e00\u6b21\u6027\u6539\u521d\u59cb\u5316\uff0c\u8bad\u7ec3\u8fc7\u7a0b\u5b8c\u5168\u4e0d\u53d8\u3002<\/p>\n<\/li>\n<li>\n<p>\u6b63\u4ea4\u521d\u59cb\u5316\u7684\u6838\u5fc3\u4f18\u52bf\uff1a$B$ \u7528\u6b63\u4ea4\u521d\u59cb\u5316\u610f\u5473\u7740 $B$ \u7684\u5217\u5411\u91cf\u662f\u4e24\u4e24\u6b63\u4ea4\u7684\uff0c\u8fd9\u4fdd\u8bc1\u4e86 $BA$ \u7684\u8868\u793a\u7a7a\u95f4\u66f4\u4e30\u5bcc\uff0c\u4e0d\u4f1a\u51fa\u73b0\u67d0\u4e9b\u65b9\u5411\u88ab\u538b\u7f29\u7684\u60c5\u51b5\u3002<\/p>\n<\/li>\n<\/ol>\n<p><strong>\u6b63\u4ea4\u521d\u59cb\u5316\u7684\u53cd\u4f8b\u548c\u8fb9\u754c\u6761\u4ef6<\/strong>\uff1a<\/p>\n<p>\u4f46\u6b63\u4ea4\u521d\u59cb\u5316\u4e0d\u662f\u4e07\u80fd\u7684\u3002\u6709\u4e2a\u573a\u666f\u6211\u8e29\u8fc7\u5751\uff1a<\/p>\n<p>\u5982\u679c\u4f60\u7684\u4efb\u52a1\u9700\u8981\u6a21\u578b\u5f7b\u5e95\u6446\u8131\u9884\u8bad\u7ec3\u77e5\u8bc6\uff08\u6bd4\u5982\u505a\u98ce\u683c\u8fc1\u79fb\uff0c\u8ba9\u6a21\u578b\u5b8c\u5168\u5fd8\u6389\u539f\u6765\u7684\u5199\u4f5c\u98ce\u683c\uff09\uff0c\u6b63\u4ea4\u521d\u59cb\u5316\u53ef\u80fd\u4e0d\u5982\u96f6\u521d\u59cb\u5316\u597d\u3002\u56e0\u4e3a\u6b63\u4ea4\u521d\u59cb\u5316\u8ba9 $BA$ 
\u4e00\u5f00\u59cb\u5c31\u6709\u975e\u96f6\u8f93\u51fa\uff0c\u6a21\u578b\u4f1a\u66f4\u5bb9\u6613\u4fdd\u7559\u9884\u8bad\u7ec3\u77e5\u8bc6\u3002<\/p>\n<p>\u53e6\u5916\uff0c\u5bf9\u4e8e\u6781\u5c0f\u7684rank\uff08\u6bd4\u5982r=2\u6216r=4\uff09\uff0c\u6b63\u4ea4\u521d\u59cb\u5316\u7684\u610f\u4e49\u4e5f\u4e0d\u5927\u2014\u2014\u7a7a\u95f4\u592a\u5c0f\uff0c\u6b63\u4ea4\u7ea6\u675f\u548c\u666e\u901a\u521d\u59cb\u5316\u5dee\u522b\u4e0d\u5927\u3002\u8fd9\u79cd\u60c5\u51b5\u4e0b\u6211\u5efa\u8bae\u76f4\u63a5\u7528\u65b9\u68482\uff08\u8c03\u5927alpha\uff09\uff0c\u6216\u8005\u5e72\u8106\u6362\u522b\u7684\u5fae\u8c03\u65b9\u6cd5\u3002<\/p>\n<h2>\u8c03\u7528\u65b9\u5f0f<\/h2>\n<h3>\u5b8c\u6574\u8bad\u7ec3\u811a\u672c<\/h3>\n<p>\u8fd9\u662f\u6700\u7ec8\u7528\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u57fa\u4e8eDeepSpeed ZeRO-2\uff1a<\/p>\n<pre><code class=\"lang-python language-python python\">#!\/usr\/bin\/env python\n# -*- coding: utf-8 -*-\n&quot;&quot;&quot;\nLoRA\u5fae\u8c03\u811a\u672c - \u652f\u6301\u68af\u5ea6\u76d1\u63a7\n&quot;&quot;&quot;\n\nimport os\nimport torch\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    TrainingArguments,\n    Trainer,\n    DataCollatorForLanguageModeling\n)\nfrom peft import LoraConfig, get_peft_model, TaskType\nfrom datasets import load_dataset\n\n# ============== \u914d\u7f6e\u533a ==============\nMODEL_NAME = &quot;meta-llama\/Llama-2-7b-hf&quot;\nDATA_PATH = &quot;.\/data\/chat_data.json&quot;\nOUTPUT_DIR = &quot;.\/output\/lora_finetuned&quot;\n\n# LoRA\u914d\u7f6e\nLORA_R = 16\nLORA_ALPHA = 32\nLORA_DROPOUT = 0.05\nLORA_TARGET_MODULES = [&quot;q_proj&quot;, &quot;v_proj&quot;]\n\n# \u8bad\u7ec3\u914d\u7f6e\nLEARNING_RATE = 2e-4\nNUM_EPOCHS = 3\nBATCH_SIZE = 4\nGRADIENT_ACCUMULATION_STEPS = 4\nMAX_GRAD_NORM = 1.0\n# ============== \u914d\u7f6e\u533a\u7ed3\u675f ==============\n\ndef main():\n    # 1. 
\u52a0\u8f7d\u6a21\u578b\u548ctokenizer\n    print(&quot;Loading model...&quot;)\n    model = AutoModelForCausalLM.from_pretrained(\n        MODEL_NAME,\n        device_map=&quot;auto&quot;,\n        torch_dtype=torch.float16,\n        load_in_8bit=False,\n        trust_remote_code=True\n    )\n\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n    tokenizer.pad_token = tokenizer.eos_token\n    tokenizer.padding_side = &quot;right&quot;\n\n    # 2. \u914d\u7f6eLoRA\n    print(f&quot;Configuring LoRA: r={LORA_R}, alpha={LORA_ALPHA}&quot;)\n    lora_config = LoraConfig(\n        task_type=TaskType.CAUSAL_LM,\n        r=LORA_R,\n        lora_alpha=LORA_ALPHA,\n        lora_dropout=LORA_DROPOUT,\n        target_modules=LORA_TARGET_MODULES,\n        bias=&quot;none&quot;,\n        init_lora_weights=&quot;gaussian&quot;  # \u7528\u65b0\u7248\u652f\u6301\u7684\u521d\u59cb\u5316\n    )\n\n    model = get_peft_model(model, lora_config)\n    model.print_trainable_parameters()\n\n    # 3. \u52a0\u8f7d\u6570\u636e\n    print(f&quot;Loading data from {DATA_PATH}&quot;)\n    dataset = load_dataset(&quot;json&quot;, data_files=DATA_PATH, split=&quot;train&quot;)\n\n    def tokenize_function(examples):\n        texts = []\n        for conv in examples[&quot;conversations&quot;]:\n            text = &quot;&lt;|im_start|&gt;user\\n&quot; + conv[0][&quot;content&quot;]\n            text += &quot;&lt;|im_end|&gt;&lt;|im_start|&gt;assistant\\n&quot; + conv[1][&quot;content&quot;]\n            text += &quot;&lt;|im_end|&gt;&quot;\n            texts.append(text)\n\n        result = tokenizer(texts, truncation=True, max_length=512)\n        result[&quot;labels&quot;] = result[&quot;input_ids&quot;].copy()\n        return result\n\n    tokenized_dataset = dataset.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[&quot;conversations&quot;]\n    )\n\n    # 4. 
\u8bad\u7ec3\u53c2\u6570\n    training_args = TrainingArguments(\n        output_dir=OUTPUT_DIR,\n        learning_rate=LEARNING_RATE,\n        num_train_epochs=NUM_EPOCHS,\n        per_device_train_batch_size=BATCH_SIZE,\n        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n        max_grad_norm=MAX_GRAD_NORM,\n        warmup_ratio=0.03,\n        lr_scheduler_type=&quot;cosine&quot;,\n        logging_steps=10,\n        save_steps=100,\n        save_total_limit=2,\n        fp16=True,\n        dataloader_num_workers=4,\n        remove_unused_columns=False,\n        report_to=[&quot;tensorboard&quot;]\n    )\n\n    # 5. \u81ea\u5b9a\u4e49Trainer\uff0c\u8bb0\u5f55\u68af\u5ea6\n    class GradientLoggingTrainer(Trainer):\n        def training_step(self, model, inputs):\n            # \u6b63\u5e38\u8bad\u7ec3\n            loss = super().training_step(model, inputs)\n\n            # \u6bcf50\u6b65\u6253\u5370\u68af\u5ea6\u7edf\u8ba1\n            if self.state.global_step % 50 == 0:\n                grad_norms = {}\n                for name, param in model.named_parameters():\n                    if param.requires_grad and param.grad is not None:\n                        grad_norms[name] = param.grad.norm().item()\n\n                # \u6253\u5370A\u548cB\u7684\u68af\u5ea6\u6bd4\n                lora_A_norms = [v for k, v in grad_norms.items() if &quot;lora_A&quot; in k]\n                lora_B_norms = [v for k, v in grad_norms.items() if &quot;lora_B&quot; in k]\n\n                if lora_A_norms and lora_B_norms:\n                    avg_A = sum(lora_A_norms) \/ len(lora_A_norms)\n                    avg_B = sum(lora_B_norms) \/ len(lora_B_norms)\n                    ratio = avg_B \/ (avg_A + 1e-8)\n                    self.log({\n                        &quot;grad_norm_A&quot;: avg_A,\n                        &quot;grad_norm_B&quot;: avg_B,\n                        &quot;grad_ratio_B_A&quot;: ratio\n                    })\n\n            return loss\n\n    # 
6. \u5f00\u59cb\u8bad\u7ec3\n    trainer = GradientLoggingTrainer(\n        model=model,\n        args=training_args,\n        train_dataset=tokenized_dataset,\n        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)\n    )\n\n    trainer.train()\n\n    # 7. \u4fdd\u5b58\n    trainer.save_model(f&quot;{OUTPUT_DIR}\/final&quot;)\n    tokenizer.save_pretrained(f&quot;{OUTPUT_DIR}\/final&quot;)\n    print(&quot;Training complete!&quot;)\n\nif __name__ == &quot;__main__&quot;:\n    main()<\/code><\/pre>\n<h3>\u4f7f\u7528DeepSpeed\u542f\u52a8<\/h3>\n<pre><code class=\"lang-bash language-bash bash\">deepspeed --num_gpus=8 train_lora.py \\\n    --deepspeed ds_config.json<\/code><\/pre>\n<p>ds_config.json\u5185\u5bb9\uff1a<\/p>\n<pre><code class=\"lang-json language-json json\">{\n  &quot;train_batch_size&quot;: &quot;auto&quot;,\n  &quot;train_micro_batch_size_per_gpu&quot;: &quot;auto&quot;,\n  &quot;gradient_accumulation_steps&quot;: &quot;auto&quot;,\n  &quot;gradient_clipping&quot;: 1.0,\n  &quot;zero_optimization&quot;: {\n    &quot;stage&quot;: 2,\n    &quot;offload_optimizer&quot;: {\n      &quot;device&quot;: &quot;cpu&quot;,\n      &quot;pin_memory&quot;: true\n    },\n    &quot;allgather_partitions&quot;: true,\n    &quot;allgather_bucket_size&quot;: 2e8,\n    &quot;reduce_scatter&quot;: true,\n    &quot;reduce_bucket_size&quot;: 2e8,\n    &quot;overlap_comm&quot;: true,\n    &quot;contiguous_gradients&quot;: true\n  },\n  &quot;fp16&quot;: {\n    &quot;enabled&quot;: true,\n    &quot;loss_scale&quot;: 0,\n    &quot;loss_scale_window&quot;: 1000,\n    &quot;initial_scale_power&quot;: 16\n  
}\n}<\/code><\/pre>\n<h2>\u4e0a\u7ebf\u540e\u8bc4\u4f30<\/h2>\n<h3>\u9a8c\u8bc1\u7ed3\u679c<\/h3>\n<p>\u6539\u5b8c\u521d\u59cb\u5316\u4e4b\u540e\uff0c\u540c\u4e00\u4e2a\u6570\u636e\u96c6\u91cd\u65b0\u8bad\u7ec3\uff1a<\/p>\n<table>\n<thead>\n<tr>\n<th>\u9636\u6bb5<\/th>\n<th>\u539f\u6765(r=16,a=32,\u96f6\u521d\u59cb\u5316)<\/th>\n<th>\u6539\u540e(r=16,a=32,\u6b63\u4ea4\u521d\u59cb\u5316)<\/th>\n<\/tr>\n<\/thead>\n<tbody>\n<tr>\n<td>Step 100, loss<\/td>\n<td>2.1<\/td>\n<td>1.95<\/td>\n<\/tr>\n<tr>\n<td>Step 500, loss<\/td>\n<td>1.4<\/td>\n<td>1.12<\/td>\n<\/tr>\n<tr>\n<td>Step 1000, loss<\/td>\n<td>1.1<\/td>\n<td>0.78<\/td>\n<\/tr>\n<tr>\n<td>\u9a8c\u8bc1\u96c6BLEU<\/td>\n<td>12.3<\/td>\n<td>18.7<\/td>\n<\/tr>\n<tr>\n<td>\u9a8c\u8bc1\u96c6ROUGE-L<\/td>\n<td>0.241<\/td>\n<td>0.358<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p>BLEU\u4ece12.3\u8df3\u523018.7\uff0c\u8fd9\u624d\u5bf9\u561b\u3002loss\u964d\u4e0d\u4ee3\u8868\u6a21\u578b\u5728\u5b66\uff0c\u53ea\u6709\u9a8c\u8bc1\u6307\u6807\u63d0\u5347\u624d\u662f\u771f\u7684\u5b66\u5230\u4e1c\u897f\u3002<\/p>\n<h3>\u8d44\u6e90\u6d88\u8017\u5bf9\u6bd4<\/h3>\n<p>\u6539\u52a8\u540e\u5bf9\u663e\u5b58\u7684\u5f71\u54cd\uff1a<\/p>\n<table>\n<thead>\n<tr>\n<th>\u914d\u7f6e<\/th>\n<th>GPU\u663e\u5b58\u5360\u7528<\/th>\n<th>\u8bad\u7ec3\u901f\u5ea6<\/th>\n<\/tr>\n<\/thead>\n<tbody>\n<tr>\n<td>\u539f\u59cbLoRA (r=16)<\/td>\n<td>~42GB (8\u5361)<\/td>\n<td>120 tokens\/sec\/GPU<\/td>\n<\/tr>\n<tr>\n<td>\u6b63\u4ea4\u521d\u59cb\u5316LoRA<\/td>\n<td>~42GB (8\u5361)<\/td>\n<td>118 tokens\/sec\/GPU<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p>\u51e0\u4e4e\u6ca1\u533a\u522b\uff0c\u521d\u59cb\u5316\u53ea\u5f71\u54cd\u53c2\u6570\u503c\uff0c\u4e0d\u5f71\u54cd\u8ba1\u7b97\u56fe\u3002\u6b63\u4ea4\u521d\u59cb\u5316\u548c\u96f6\u521d\u59cb\u5316\u5728\u663e\u5b58\u5360\u7528\u548c\u901f\u5ea6\u4e0a\u662f\u4e00\u6837\u7684\u3002<\/p>\n<p>\u786c\u4ef6\u73af\u5883\uff1a8\u5361A100 80GB\uff0cUbuntu 22.04\uff0cCUDA 12.1\uff0cPyTorch 2.1.0\uff0ctransformers 
4.36.0\uff0cpeft 0.7.1\u3002<\/p>\n<h3>\u68af\u5ea6\u76d1\u63a7\u7ed3\u679c<\/h3>\n<p>\u6539\u5b8c\u540e\u540c\u4e00\u6279\u6570\u636e\u518d\u770b\u68af\u5ea6\u5206\u5e03\uff1a<\/p>\n<pre><code>[lora_B.weight] grad_norm=0.0218\n[lora_A.weight] grad_norm=0.0195  # \u4ece0.0001\u5347\u52300.0195\uff01<\/code><\/pre>\n<p>A\u548cB\u7684\u68af\u5ea6\u6bd4\u4ece200:1\u53d8\u6210\u4e86\u7ea61:1\uff0c\u8fd9\u624d\u662f\u5065\u5eb7\u7684\u8bad\u7ec3\u72b6\u6001\u3002<\/p>\n<h2>\u5e38\u89c1\u5751<\/h2>\n<h3>\u5982\u679c\u8fd9\u6837\u6539\u8fd8\u4e0d\u884c\u600e\u4e48\u529e<\/h3>\n<p>\u6539\u4e86\u521d\u59cb\u5316\u4e4b\u540e\u9a8c\u8bc1\u6307\u6807\u8fd8\u662f\u4e0d\u52a8\uff0c\u8fd9\u65f6\u5019\u522b\u614c\uff0c\u95ee\u9898\u53ef\u80fd\u5728\u522b\u7684\u5730\u65b9\u3002<\/p>\n<p><strong>1. \u5b66\u4e60\u7387\u4e0d\u5bf9<\/strong><\/p>\n<p>\u6211\u9047\u5230\u8fc7\u5b66\u4e60\u7387\u592a\u5927\u5bfc\u81f4\u8bad\u7ec3\u9707\u8361\u7684\u60c5\u51b5\u30022e-4\u5bf97B\u6a21\u578b\u6765\u8bf4\u7b97\u6bd4\u8f83\u5927\u7684\uff0c\u5982\u679c\u7528\u4e86\u6b63\u4ea4\u521d\u59cb\u5316\uff0c\u68af\u5ea6\u66f4\u5927\u4e86\uff0c\u5efa\u8bae\u628a\u5b66\u4e60\u7387\u964d\u52301e-4\u8bd5\u8bd5\u3002<\/p>\n<pre><code class=\"lang-python language-python python\">training_args = TrainingArguments(\n    learning_rate=1e-4,  # \u4ece2e-4\u964d\u4e0b\u6765\n    ...\n)<\/code><\/pre>\n<p><strong>2. \u6570\u636e\u8d28\u91cf\u95ee\u9898<\/strong><\/p>\n<p>\u524d\u9762\u8bf4\u7684token\u5206\u5e03\u95ee\u9898\u518d\u5f3a\u8c03\u4e0b\u3002\u5982\u679c\u6570\u636e\u91cc\u67d0\u4e2atoken\u5360\u6bd4\u8d85\u8fc730%\uff0c\u6a21\u578b\u4f1a&#8220;\u5077\u61d2&#8221;\uff0c\u53ea\u5b66\u9884\u6d4b\u8fd9\u4e2atoken\u3002\u68c0\u67e5\u4e0bspecial token\u6bd4\u4f8b\uff0c\u8d85\u8fc75%\u5c31\u6709\u95ee\u9898\u3002<\/p>\n<p><strong>3. 
\u6570\u636e\u91cf\u592a\u5c11<\/strong><\/p>\n<p>3000\u6761\u6570\u636e\u5176\u5b9e\u504f\u5c11\uff0c\u8bad3\u4e2aepoch\u53ef\u80fd\u4e0d\u591f\u3002\u6211\u540e\u6765\u628aepoch\u63d0\u52305\uff0c\u9a8c\u8bc1\u6307\u6807\u53c8\u6da8\u4e86\u4e00\u70b9\u3002<\/p>\n<p><strong>4. \u6a21\u578b\u672c\u6765\u5c31\u4e0d\u9002\u5408\u8fd9\u4e2a\u4efb\u52a1<\/strong><\/p>\n<p>\u8bf4\u5b9e\u8bdd\uff0c\u6709\u65f6\u5019\u95ee\u9898\u4e0d\u662f\u8bad\u7ec3\u7684\u95ee\u9898\uff0c\u662f\u6a21\u578b\u672c\u8eab\u7684\u80fd\u529b\u8fb9\u754c\u3002\u6bd4\u5982\u4f60\u7528\u4e2d\u6587\u6a21\u578b\u505a\u82f1\u6587\u4efb\u52a1\uff0c\u6216\u80057B\u6a21\u578b\u505a\u590d\u6742\u63a8\u7406\uff0c\u6548\u679c\u5dee\u662f\u6b63\u5e38\u7684\uff0c\u4e0d\u662fLoRA\u7684\u9505\u3002\u8fd9\u79cd\u60c5\u51b5\u6362\u4e2a\u5927\u6a21\u578b\u6216\u8005\u4e13\u7528\u6a21\u578b\u66f4\u5b9e\u5728\u3002<\/p>\n<p><strong>5. target_modules\u9009\u9519\u4e86<\/strong><\/p>\n<p>\u5982\u679c\u4f60\u53ea\u9009\u4e86q_proj\u548cv_proj\uff0c\u4f46\u4efb\u52a1\u9700\u8981\u7406\u89e3\u8bed\u4e49\u5173\u7cfb\uff0c\u53ef\u80fd\u52a0\u4e0ak_proj\u548co_proj\u6548\u679c\u66f4\u597d\u3002\u5982\u679c\u4efb\u52a1\u504f\u751f\u6210\uff0c\u52a0\u4e0affn\u5c42\u7684gate_proj\u548cup_proj\u53ef\u80fd\u6709\u5947\u6548\u3002<\/p>\n<h3>\u53e6\u4e00\u4e2a\u5bb9\u6613\u8e29\u7684\u5751\uff1a\u6570\u636etoken\u5206\u5e03<\/h3>\n<p>\u6211\u540e\u6765\u53c8\u9047\u5230\u8fc7\u4e00\u6b21\u7c7b\u4f3c\u95ee\u9898\uff0c\u8fd9\u6b21\u4e0d\u662f\u521d\u59cb\u5316\u95ee\u9898\uff0c\u800c\u662f\u6570\u636e\u95ee\u9898\u3002<\/p>\n<p>\u6570\u636e\u96c6\u91cc\u67d0\u4e2a\u7279\u6b8atoken\u5360\u6bd4\u8d85\u8fc740%\uff0c\u5bfc\u81f4\u6a21\u578b\u53ea\u8981\u9884\u6d4b\u8fd9\u4e2atoken\u5c31\u80fd\u628aloss\u62c9\u4f4e\uff0c\u4f46\u5b9e\u9645\u4e0a\u4ec0\u4e48\u90fd\u6ca1\u5b66\u5230\u3002\u8fd9\u79cd\u60c5\u51b5\u4e0b\u68af\u5ea6\u6d88\u5931\u662f\u8868\u8c61\uff0c\u771f\u6b63\u7684\u95ee\u9898\u662f\u6570\u636e\u4e0d\u5e73\u8861\u3002<
\/p>\n<h2>\u603b\u7ed3<\/h2>\n<p>\u56de\u8fc7\u5934\u770b\u8fd9\u4e2a\u95ee\u9898\uff0c\u5176\u5b9e\u4e0d\u590d\u6742\uff1a<\/p>\n<ol>\n<li>\n<p><strong>LoRA\u7684B\u96f6\u521d\u59cb\u5316+A\u968f\u673a\u521d\u59cb\u5316 = A\u88ab\u9501\u6b7b<\/strong>\u3002\u8fd9\u4e0d\u662fbug\uff0c\u662f\u8bbe\u8ba1\u9009\u62e9\uff0c\u4f46\u5728\u67d0\u4e9b\u573a\u666f\u4e0b\u4f1a\u5bfc\u81f4\u8bad\u7ec3\u65e0\u6548\u3002<\/p>\n<\/li>\n<li>\n<p><strong>loss\u4e0b\u964d\u4e0d\u7b49\u4e8e\u6a21\u578b\u5728\u5b66<\/strong>\u3002\u5f97\u770b\u9a8c\u8bc1\u6307\u6807\u3001\u68af\u5ea6\u5206\u5e03\u3001optimizer\u72b6\u6001\uff0c\u7efc\u5408\u5224\u65ad\u3002<\/p>\n<\/li>\n<li>\n<p><strong>rank\u548calpha\u4e0d\u662f\u8d8a\u5927\u8d8a\u597d<\/strong>\uff0c\u4f46\u592a\u5c0f\u4f1a\u5bfc\u81f4\u8868\u793a\u80fd\u529b\u4e0d\u8db3\u3002\u6211\u73b0\u5728\u4e00\u822cr=16\u8d77\u6b65\uff0c\u5982\u679c\u4efb\u52a1\u590d\u6742\u6216\u8005\u6570\u636e\u91cf\u5927\u624d\u5f8032\u300164\u8c03\u3002<\/p>\n<\/li>\n<li>\n<p>\u6570\u636e\u5206\u5e03\u95ee\u9898\u4e5f\u5f97\u6392\u67e5\uff0c\u7279\u522b\u662f\u7528chat\u6a21\u677f\u7684\u6570\u636e\u96c6\uff0csystem prompt\u3001user\/assistant\u6807\u8bb0\u7684\u5206\u5e03\u4f1a\u76f4\u63a5\u5f71\u54cd\u67d0\u4e9b\u5c42\u7684\u6fc0\u6d3b\u3002<\/p>\n<\/li>\n<li>\n<p>\u6b63\u4ea4\u521d\u59cb\u5316\u662f\u6b63\u7edf\u89e3\u6cd5\uff0c\u6bd4\u6539\u65b9\u5dee\u548c\u68af\u5ea6\u5747\u8861\u90fd\u5e72\u51c0\u5229\u843d\u3002\u9664\u975e\u4f60\u6709\u7279\u6b8a\u9700\u6c42\uff08\u5f7b\u5e95\u5fd8\u6389\u9884\u8bad\u7ec3\u77e5\u8bc6\uff09\uff0c\u5426\u5219\u522b\u7528\u96f6\u521d\u59cb\u5316\u3002<\/p>\n<\/li>\n<\/ol>\n<p>\u6700\u540e\u63d0\u9192\u4e00\u53e5\uff1a\u5982\u679c\u4f60\u73b0\u5728\u7528\u7684peft\u7248\u672c\u6bd4\u8f83\u8001\uff0c\u5347\u7ea7\u52300.7+\u4f1a\u597d\u5f88\u591a\uff0c\u90a3\u4e2a\u7248\u672c\u5bf9\u521d\u59cb\u5316\u505a\u4e86\u4f18\u5316\uff0c\u800c\u4e14\u591a\u4e86gradient 
checkpointing\u7684\u517c\u5bb9\u5904\u7406\u3002<\/p>","protected":false},"excerpt":{"rendered":"<p>loss\u660e\u660e\u5728\u964d\uff0c\u9a8c\u8bc1\u96c6\u6307\u6807\u5374\u4e00\u52a8\u4e0d\u52a8\u3002\u67e5\u4e86\u534a\u5929\u53d1\u73b0\u53ea\u6709\u5c11\u6570\u51e0\u5c42LoRA\u6743\u91cd\u5728\u66f4\u65b0\uff0c\u5176\u4ed6\u5c42\u68af\u5ea6\u51e0\u4e4e\u4e3a\u96f6\u3002\u8fd9\u4e2a\u5751\u82b1\u4e86\u6211\u4e24\u5929\u624d\u5b9a\u4f4d\u6e05\u695a\uff0c\u5206\u4eab\u4e0b\u6392\u67e5\u601d\u8def\u548c\u5177\u4f53\u64cd\u4f5c\u3002<\/p>","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[401],"tags":[402,144,180,467,530,458],"class_list":["post-708","post","type-post","status-publish","format-standard","hentry","category-ai","tag-lora","tag-pytorch","tag-transformer","tag-467","tag-530","tag-458"],"views":4,"_links":{"self":[{"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=\/wp\/v2\/posts\/708","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=708"}],"version-history":[{"count":1,"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=\/wp\/v2\/posts\/708\/revisions"}],"predecessor-version":[{"id":720,"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=\/wp\/v2\/posts\/708\/revisions\/720"}],"wp:attachment":[{"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=708"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=%2Fwp%2Fv2%2F
categories&post=708"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.liaoxinghui.com\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=708"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}