如何修改大模型的位置编码 --以LLama为例




看了上面的公式,我在考虑:为什么需要建立 $\lambda$ 和 $k$ 之间的关系?

因为我们要把 $\beta$ 进制修改为 $\beta\lambda$ 进制。$k$ 是我们可以直接确定的:比如我们需要把位置编码缩小 10 倍,直接设置 $k=10$ 即可;但是采用 NTK 的方式,同样要缩小 10 倍时,我们就不确定 $\lambda$ 应该怎么设置了。所以需要建立 $\lambda$ 和 $k$ 之间的关系,从上图可知,$\lambda=k^{2/(d-2)}$。

class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE): produces the cos/sin tables used to rotate queries and keys.

    Args:
        dim: Size of the rotated feature dimension. One inverse frequency is
            created for every even index in ``range(0, dim, 2)``.
        max_position_embeddings: Number of positions covered by the cos/sin
            tables that are pre-computed in ``__init__``.
        base: Base of the frequency progression,
            ``inv_freq[i] = 1 / base ** (2 * i / dim)``.
        device: Device on which the buffers are created.
        scaling_factor: Divisor applied to the positions of the pre-computed
            tables only; ``forward`` uses ``position_ids`` unchanged.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # nn.Module has to be initialised before any buffer can be registered.
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @property
    def sin_cached(self):
        """Deprecated accessor for the sine table pre-computed in ``__init__``."""
        import warnings

        warnings.warn(
            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class",
            FutureWarning,
            stacklevel=2,
        )
        return self._sin_cached

    @property
    def cos_cached(self):
        """Deprecated accessor for the cosine table pre-computed in ``__init__``."""
        import warnings

        warnings.warn(
            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class",
            FutureWarning,
            stacklevel=2,
        )
        return self._cos_cached

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``x`` is only read for its device type and dtype. The outputs have one
        row per position and ``2 * len(inv_freq)`` columns (the frequencies
        are concatenated with themselves).
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        # Fall back to "cpu" when the device type is "mps" (or not a string).
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

知道了 $\lambda$ 和 $k$ 之间的关系,那么代码怎么实现呢?我们只需要修改 $\beta\lambda$ 的结果即可,其中 $\beta = 10000^{2/d}$。

import transformers

# Keep a reference to the unpatched constructor so the wrapper can delegate to it.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None, *args, **kwargs):
    """Replacement ``__init__`` that applies NTK-aware scaling to the RoPE base.

    The caller's ``max_position_embeddings`` is overridden with 16384 and
    ``base`` is multiplied by ``k ** (dim / (dim - 2))`` before delegating to
    the original constructor. Any further arguments (for example
    ``scaling_factor``) are forwarded unchanged, so the patch does not break
    callers that pass them.
    """

    #The method is just these three lines
    max_position_embeddings = 16384
    k = 8 #Alpha value
    base = base * k ** (dim / (dim-2)) #Base change formula

    old_init(self, dim, max_position_embeddings, base, device, *args, **kwargs)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

为什么采用 `base = base * k ** (dim / (dim-2))`?原始的 base 为 10000,而 $\beta = 10000^{2/d}$,可以发现在代码里面修改的仅仅是 base,即 $\beta = base^{2/d}$;而 $\lambda=k^{2/(d-2)}$,我们需要把 $k$ 和 base 进行融合,修改成 base 乘以 $k$ 的某次幂的形式,形成新的 base。$\lambda$ 中 $k$ 的指数可以改写为 $\frac{2}{d-2}=\frac{2}{d-2}\cdot\frac{d}{2}\cdot\frac{2}{d}=\frac{d}{d-2}\cdot\frac{2}{d}$,即 $\lambda=(k^{d/(d-2)})^{2/d}$,于是 $\beta\lambda=(base\cdot k^{d/(d-2)})^{2/d}$。因为 $2/d$ 次幂在 RoPE 的代码里面已经计算过了:

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

所以我们将新的 base 赋值为 `base * k ** (dim / (dim-2))`。

Dynamic NTK


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, enlarging the RoPE base for long inputs.

        When the largest position id plus one exceeds
        ``self.max_position_embeddings``, ``inv_freq`` is recomputed from a
        base scaled by the sequence length before delegating to the parent
        ``forward``. Shorter inputs use whatever ``inv_freq`` is currently
        registered.
        """
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:  # apply NTK scaling only beyond the pre-training length
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            # NOTE: the buffer is overwritten in place and is not restored by
            # this method when a later call has a shorter sequence.
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin


