如何修改大模型的位置编码 --以LLama为例

最近在看RoPE相关内容,一些方法通过简单修改位置编码就可以无需训练支持更长的文本内容。由于一些模型,已经训练好了,但是怎么修改已经训练好的模型位置编码。查了以下相关代码,记录一下。原理这里就不细讲了,贴几个相关博客。
十分钟读懂旋转编码(RoPE)
Transformer升级之路:11、将β进制位置进行到底
Transformer升级之路:10、RoPE是一种β进制编码

NTK

下图为NTK的原理证明:截取自Transformer升级之路:10、RoPE是一种β进制编码
在这里插入图片描述
在这里插入图片描述

看了上面的公式,我在考虑为什么需要建立 λ \lambda λ和k之间的关系?

因为我们要修改 β \beta β进制为 β λ \beta\lambda βλ,由于k我们是可以知道的比如我们需要把位置编码缩小为10倍,直接设置k为10,但是采用NTK的方式,维度缩小为10倍,那么我们就不确定, λ \lambda λ怎么设置了。所以需要简历 λ \lambda λ和k之间的关系,从上图可知, λ = k 2 / ( d − 2 ) \lambda=k^{2/(d-2)} λ=k2/(d2)
下面开始理解如何修改RoPE为NTK的形式:
以下为LLama的RoPE代码实现

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @property
    def sin_cached(self):
        logger.warning_once(
            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._sin_cached

    @property
    def cos_cached(self):
        logger.warning_once(
            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._cos_cached

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

λ \lambda λ和k之间的关系,那么代码怎么实现呢,我们只需要修改 β λ \beta\lambda βλ的结果即可,其中 β \beta β 1000 0 2 / d 10000^{2/d} 100002/d
参考代码为:点击

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):

    #The method is just these three lines
    max_position_embeddings = 16384
    k = 8 #Alpha value
    base = base * k ** (dim / (dim-2)) #Base change formula

    old_init(self, dim, max_position_embeddings, base, device)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

为什么采用base = base * k ** (dim / (dim-2)),原始的base为10000, 而 β \beta β 1000 0 2 / d 10000^{2/d} 100002/d,发现在代码里面修改的仅仅是base的结果, β \beta β= b a s e 2 / d base^{2/d} base2/d,而 λ = k 2 / ( d − 2 ) \lambda=k^{2/(d-2)} λ=k2/(d2),我们需要把k和base进行融合,修改成,base*k的形式形成新的base, λ \lambda λ等于k的指数 2 / ( d − 2 ) ∗ d / 2 ∗ 2 / d = d / ( d − 2 ) ∗ 2 / d 2/(d-2)*d/2*2/d=d/(d-2)*2/d 2/(d2)d/22/d=d/(d2)2/d λ = ( k d / ( d − 2 ) ) 2 / d \lambda=(k^{d/(d-2)})^{2/d} λ=(kd/(d2))2/d,因为2/d在RoPE的代码里面已经计算过了:

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)

所以我们则赋值新的base为base * k ** (dim / (dim-2))。

Dynamic NTK

Dynamic在NTK的基础上进行简单的修改,采用NTK的时候更加灵活。
截图源于:RoPE到底是何方神圣(数学推理+优化方法)
在这里插入图片描述
代码实现:

class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:#只有长度超过了预训练的阈值,进行NTK
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin

相关推荐

  1. 调试ffmpeg,演示gdb如何定位内存被修改

    2024-03-25 13:54:04       39 阅读
  2. 模型llama.cp编译

    2024-03-25 13:54:04       28 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-25 13:54:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-25 13:54:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-25 13:54:04       82 阅读
  4. Python语言-面向对象

    2024-03-25 13:54:04       91 阅读

热门阅读

  1. 封装的charts使用 vue2

    2024-03-25 13:54:04       37 阅读
  2. Springboot vue elementui 电子商城系统源码

    2024-03-25 13:54:04       35 阅读
  3. 蓝桥杯刷题--python-31-单调栈

    2024-03-25 13:54:04       38 阅读
  4. 2457. 美丽整数的最小增量

    2024-03-25 13:54:04       41 阅读
  5. 面试宝典:MySQL中索引为什么使用B+树的深度分析

    2024-03-25 13:54:04       36 阅读
  6. es同义词配置规则

    2024-03-25 13:54:04       44 阅读
  7. 天秀基础算法 - 二分查找和二分答案

    2024-03-25 13:54:04       33 阅读
  8. SpringCloud优缺点及适合场景

    2024-03-25 13:54:04       40 阅读
  9. npm 包管理工具:常用命令详解与使用指南

    2024-03-25 13:54:04       35 阅读
  10. kingbaseESV8分区表

    2024-03-25 13:54:04       35 阅读
  11. Github 2024-03-21 开源项目日报 Top10

    2024-03-25 13:54:04       30 阅读
  12. 计算方法(第3版)——学习笔记(一)

    2024-03-25 13:54:04       34 阅读
  13. Python之关键字传参(**kwargs)妙处

    2024-03-25 13:54:04       26 阅读