大模型都在用的:旋转位置编码

写在前面

        这篇文章提到了绝对位置编码和相对位置编码,但是他们都有局限性,比如绝对位置编码不能直接表征token的相对位置关系;相对位置编码过于复杂,影响效率。于是诞生了一种用绝对位置编码的方式实现相对位置编码的编码方式——旋转位置编码(Rotary Position Embedding, RoPE),兼顾效率和相对位置关系。

        RoPE的核心思想是通过旋转的方式将位置信息编码到每个维度,从而使得模型能够捕捉到序列中元素的相对位置信息。现在已经在很多大模型证明了其有效性,比如ChatGLM、LLaMA等。

一、RoPE的优点

1.真正的旋转位置编码

        Transformer的原版位置编码也使用了三角函数,但它生成的是每个位置的绝对编码,三角函数的主要用途是生成具有可区分性的周期性模式,也没有应用旋转变换的概念,因此属于绝对位置编码。同时原版的编码使用加法,在多层传递后导致位置信息的稀释,如下图 (没想到这张图也有被当做反面典型的时候吧):

        RoPE不是简单的加法,而是通过复数乘法实现旋转变换,这种旋转是将位置信息融入到token表示中的关键机制。RoPE在实现过程中通过乘法操作融入位置信息,与模型中的Q和K深度融合,将旋转操作真正植入Attention机制内部,强化了位置编码信息的作用

2.更好的相对位置信息编码

        注意力机制通过计算Embedding的内积来确定它们之间的关系强度。

        使用RoPE时,两个位置的编码通过旋转变换后的内积,自然地包含了它们之间的相对位置信息。这是因为旋转操作保持了内积的性质,使得内积计算不仅反映了token的内容相似性,还反映了它们的位置关系。

3.更适用于多维输入

        这点很有意思,传统的Transformer位置编码主要针对一维序列,如文本序列。然而,在某些任务中,输入可能是二维或更高维的数据,如图像或视频数据。旋转位置编码可以更灵活地应用于多维输入数据,通过对不同维度的位置信息进行编码,使得模型能够更好地理解多维数据中的位置关系。

4. 更善于处理长序列

        RoPE可以减少位置信息的损失。在深层网络中,RoPE通过乘法操作融入位置信息,乘法操作有助于在深层网络中保持位置信息的完整性。在处理一个长文本时,RoPE通过在每一层的自注意力计算中使用旋转变换,确保了位置信息能够被有效保留和利用,即使是在模型的较深层次。

二、公式

        既然旋转的位置编码有这么多优点,那怎么实现位置编码的旋转呢,其实网上有很多介绍的文章。大概意思就是复数可以通过乘以e的幂来旋转角度,其中幂就是角度,再结合欧拉公式推出三角函数的表达,大致流程如下。

        欧拉公式:

e^{i\theta }=cos\theta +i\cdot sin\theta        (1)

        复数旋转角度θ:

(x+y\cdot i)e^{i\theta }                (2)

        将(1)带入(2):

(x+y\cdot i)e^{i\theta }=(xcos\theta -ysin\theta )+i(xsin\theta +ycos\theta )        (3)

        这块东西苏剑林老师已经从数学的角度进行过很深入的推导,这里的融合式部分,我就不班门弄斧了。我今天提供一种朴素的思考过程,从代码实现的角度思考如何进行旋转

        众所周知,一维向量是不能旋转的,那我们就旋转一个[2,d]的二维向量q,并且设x=q[0],y=q[1]即:

x=[q_0,q_1...,q_{d/2-1}],y=[q_{d/2},q_{d/2+1}...,q_{d-1}]        (4)

        要旋转q很容易,乘以旋转矩阵就可以了,如果我们要旋转角度θ:

R(\theta )=[x , y]\cdot \begin{bmatrix} cos(\theta ) & -sin(\theta )\\ sin(\theta ) & cos(\theta ) \end{bmatrix}                (5)

        展开之后,结果如下:

\begin{bmatrix} q_0 \\ ... \\ q_{d/2-1}\\ q_{d/2}\\ ...\\ q_{d-1} \end{bmatrix} \bigotimes \begin{bmatrix} cos\theta _0 \\ ... \\ cos\theta _{d-2}\\ cos\theta _{0}\\ ...\\ cos\theta _{d-2} \end{bmatrix} + \begin{bmatrix} -q_{d/2} \\ ... \\ -q_{d-1}\\ q_{0}\\ ...\\ q_{d/2-1} \end{bmatrix} \bigotimes \begin{bmatrix} cos\theta _0 \\ ... \\ cos\theta _{d-2}\\ cos\theta _{0}\\ ...\\ cos\theta _{d-2} \end{bmatrix}        (6)

        上面的\theta = \frac{pos}{10000^{\frac{2i}{d_{model}}}},很眼熟吧,就是沿用了transformer的机制,这里有详细的介绍。

        而且大家看到字母q也大概能猜到,这就是Attention中的Q,同样的操作也可以对K使用。经过上述操作,其实已经以旋转的方式将位置编码融合到Attention机制内部。

        下面就是根据式子(6)的代码实现了。这里提前说一句,ChatGLM的Q和K的形状都是[b,1,32,64],其中b是token_ids的长度;32是multi-head的个数;64将被拆成两部分,每部分32,也就是上面的x,y,下面开始代码实现部分。

三、代码实现

        我们以ChatGLM的代码为例,展示一下RoPE的使用,以下代码都在modeling_chatglm.py文件中,一条训练数据:

{"context": "你好", "target": "你好,我是大白话"}

1.字符串转换成token_ids

[ 5,  74874, 130001, 130004,  5,  74874, 6,  65806,  63850, 95351, 130005]

2.计算position_ids

        根据上面的token_ids计算出position_ids:

[[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],
 [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]]

        解释一下position_ids:第一行表示序列中每个元素的全局位置,第一个“2”表明context结束了,target要开始了,后面所有的2都是target部分;第二行则细化到更具体的局部位置,从1开始表征整个target的内容,这样用两个维度的编码很优雅的体现了context和target,这种层次化处理对于理解上下文非常重要。

        代码如下:

    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
        """
        根据token_ids生成position_ids
        :param input_ids: 这里是[[ 5, 74874, 130001, 130004, 5, 74874, 6, 65806, 63850, 95351, 130005]]
        :param mask_positions: 2 输出的第1维mask掉几位,即这一位及其前面都是0,后面是1,2...
        :param device:
        :param use_gmasks:
        :return: [[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                    [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]]
        """
        batch_size, seq_length = input_ids.shape
        if use_gmasks is None:
            use_gmasks = [False] * batch_size
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        if self.position_encoding_2d:
            # 会走这一分支
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                position_ids[i, context_length:] = mask_positions[i]
            block_position_ids = [torch.cat((
                torch.zeros(context_length, dtype=torch.long, device=device),
                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
            )) for context_length in context_lengths]
            block_position_ids = torch.stack(block_position_ids, dim=0)
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        else:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    position_ids[i, context_length:] = mask_positions[i]

        return position_ids

3.角度序列Embedding

        接下来,将position_ids转换成角度序列Embedding,下表中每个格的公式为

\theta_i = m\cdot \frac{1}{10000^\frac{2\cdot i}{d}}

        其中m是position_ids中元素的数值;i是编码的索引,ChatGLM使用两个0-31拼接;d是维度,hidden_size // (num_attention_heads * 2)=46:

        第一部分:position_ids=[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],每个值编码成长度64的角度序列:

m | i01310131
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
1m=1, i=0m=1, i=1m=1, i=31m=1, i=0m=1, i=1m=1, i=31
2m=2, i=0m=2, i=1m=2, i=31m=2, i=0m=2, i=1m=2, i=31
...
2m=2, i=0m=2, i=1m=2, i=31m=2, i=0m=2, i=1m=2, i=31

        第二部分:block_position_ids=[0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]

m | i01310131
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
1m=1, i=0m=1, i=1m=1, i=31m=1, i=0m=1, i=1m=1, i=31
...
8m=8, i=0m=8, i=1m=8, i=31m=8, i=0m=8, i=1m=8, i=31

代码如下:

class RotaryEmbedding(torch.nn.Module):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        """
        根据position_ids计算旋转角度的Embedding
        :param dim: 这里hidden_size // (num_attention_heads * 2)=46,其中hidden_size=4096 num_attention_heads=32
        :param base:
        :param precision:
        :param learnable:
        """
        super().__init__()
        # 初始化“频率”,可以理解为position_id每增加1,增加的角度,是Embedding形式的。
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            # 1.对position_ids去重并正序排列得到t,如:[[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]] --> t=[[0, 1, 2]]
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            # 2.t与初始化好的“频率”做外积,得到每个position_id的角度,是Embedding
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # 3.每个Embedding重复叠加一次
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # 4.算cos和sin,并增加维度
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    # 类似于查表,根据每个position_id获取相应的Embedding
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    ......

4.截取拼接Q和K

        这一步对Q或者K做截断,并将第二段取反拼在第一段的前面,拼接成公式第二项的q部分。

上述3、4流程示意图:

代码如下:

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  

5.旋转位置编码融合

        将旋转位置编码融合到Q和K中,计算第一部分的cos(\theta1)和sin(\theta1),并与输入的Q1、K1做乘法融合;计算第二部分的cos(\theta1)和sin(\theta1),并与输入的Q1、K1做乘法融合,最后将Q和K分别拼接,组成融合了旋转位置编码的新Q和K。整体流程图如下,其中rotary_pos_emb是上图,也就是步骤3、4:

代码如下:

def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    # 类似于查表,根据每个position_id获取相应的Embedding
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    # 执行旋转位置编码与QK的融合
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k


# 整体流程如下
# 1.拆分出Q1、Q2、K1、K2
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
# 2.计算旋转Embedding
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
                position_ids[:, 1, :].transpose(0, 1).contiguous()
# 3.旋转位置编码融合
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
# 4.将拆分出的Q1、Q2、K1、K2合并成新的Q、K
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))

        位置编码对于Transformer的重要性毋庸置疑,旋转位置编码也确实解决了一些问题。最有意思的就是它是一个二维编码,将旋转信息通过乘法操作融入Attention机制内部,强化了位置编码信息,现在已经有很多开源大模型都使用了旋转位置编码,可见其效果不俗。

        旋转位置编码就介绍到这里,关注不迷路(#^.^#)

关注订阅号了解更多精品文章

交流探讨、商务合作请加微信

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/575181.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LS2K1000LA基础教程

基于LS2K1000LA的基础教程 by 南京工业大学 孙冬梅 于 2024.4.25 文章目录 基于LS2K1000LA的基础教程一、目的二、平台1.硬件平台2.软件平台 三、测试0.开发板开机及编译器配置0.1 开发板控制台0.2 虚拟机编译器配置 1. 简单应用编程1.helloworld.c2. fileio 文件操作3.proce…

Scrapy 爬虫教程:从原理到实战

Scrapy 爬虫教程:从原理到实战 一、Scrapy框架简介 Scrapy是一个由Python开发的高效网络爬虫框架,用于从网站上抓取数据并提取结构化信息。它采用异步IO处理请求,能够同时发送多个请求,极大地提高了爬虫效率。 二、Scrapy运行原…

入坑 Java

原文:https://blog.iyatt.com/?p11305 前言 今天(2023.8.31)有个学长问我接不接一个单子,奈何没学过 Java,本来不打算接的。只是报酬感觉还不错,就接了。 要求的完成时间是在10月初,总共有一…

Spring Boost + Elasticsearch 实现检索查询

需求:对“昵称”进行“全文检索查询”,对“账号”进行“精确查询”。 认识 Elasticsearch 1. ES 的倒排索引 正向索引 对 id 进行检索速度很快。对其他字段即使加了索引,只能满足精确查询。模糊查询时,逐条数据扫描&#xff0c…

编译原理实验课

本人没咋学编译原理,能力有限,写的不好轻点喷,大佬路过的话,那你就路过就好 东大编译原理实验课原题,22年 1. 基本题:简单的扫描器设计 【问题描述】 熟悉并实现一个简单的扫描器,设计扫描器…

C++ | Leetcode C++题解之第49题字母异位词分组

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<vector<string>> groupAnagrams(vector<string>& strs) {// 自定义对 array<int, 26> 类型的哈希函数auto arrayHash [fn hash<int>{}] (const array<int, 26>&…

黑马点评(十二) -- UV统计

一 . UV统计-HyperLogLog 首先我们搞懂两个概念&#xff1a; UV&#xff1a;全称Unique Visitor&#xff0c;也叫独立访客量&#xff0c;是指通过互联网访问、浏览这个网页的自然人。1天内同一个用户多次访问该网站&#xff0c;只记录1次。 PV&#xff1a;全称Page View&…

linux权限维持(四)

6.inetd服务后门 inetd 是一个监听外部网络请求 ( 就是一个 socket) 的系统守护进程&#xff0c;默认情况下为 13 端口。当 inetd 接收到 一个外部请求后&#xff0c;它会根据这个请求到自己的配置文件中去找到实际处理它的程序&#xff0c;然后再把接收到的 这个socket 交给那…

机器学习 -- 分类问题

场景 探讨了一个回归任务——预测住房价格&#xff0c;用到了线性回归、决策树以及随机森林等各种算法。本次中我们将把注意力转向分类系统。我们曾经对MNIST进行了分类任务&#xff0c;这次我们重新回到这里&#xff0c;细致的再来一次。 开始 获取数据 Scikit-Learn提供了…

力扣爆刷第127天之动态规划五连刷(整数拆分、一和零、背包)

力扣爆刷第127天之动态规划五连刷&#xff08;整数拆分、一和零、背包&#xff09; 文章目录 力扣爆刷第127天之动态规划五连刷&#xff08;整数拆分、一和零、背包&#xff09;关于0 1 背包问题的总结01背包遍历顺序&#xff1a;完全背包遍历顺序&#xff1a; 一、343. 整数拆…

Lock-It for Mac(应用程序加密工具)

OSXBytes Lock-It for Mac是一款功能强大的应用程序加密工具&#xff0c;专为Mac用户设计。该软件具有多种功能&#xff0c;旨在保护用户的隐私和数据安全。 Lock-It for Mac v1.3.0激活版下载 首先&#xff0c;Lock-It for Mac能够完全隐藏应用程序&#xff0c;使其不易被他人…

【Pytorch】(十四)C++ 加载TorchScript 模型

文章目录 &#xff08;十四&#xff09;C 加载TorchScript 模型Step 1: 将PyTorch模型转换为TorchScriptStep 2: 将TorchScript序列化为文件Step 3: C程序中加载TorchScript模型Step 4: C程序中运行TorchScript模型 【Pytorch】&#xff08;十三&#xff09;PyTorch模型部署: T…

平衡二叉树、红黑树、B树、B+树

Tree 1、前言2、平衡二叉树和红黑树3、B树和B树3.1、B树的构建3.2、B树和B树的区别3.3、数据的存储方式 1、前言 本文侧重在理论方面对平衡二叉树、红黑树、B树和B树的各方面性能进行比较。不涉及编程方面的实现。而关于于平衡二叉树在C中的实现&#xff0c;我的上一篇文章平衡…

Nginx基本使用 反向代理与负载均衡

什么是Nginx Nginx (engine x) 是一个高性能的HTTP和反向代理web服务器。 其特点是占有内存少&#xff0c;并发能力强&#xff0c;nginx的并发能力在同类型的网页服务器中表现较好&#xff0c;而且几乎可以做到7*24不间断运行&#xff0c;即使运行数个月也不需要重新启动。 …

操作系统安全:Linux安全审计,Linux日志详解

「作者简介」&#xff1a;2022年北京冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础对安全知识体系进行总结与归纳&#xff0c;著作适用于快速入门的 《网络安全自学教程》&#xff0c;内容涵盖系统安全、信息收集等…

【树莓派】yolov5 Lite,目标检测,树莓派4B,推理v5lite-e_end2end.onnx,摄像头实时目标检测

文章目录 YOLOv5 Lite: 在树莓派上轻松运行目标检测1. 环境配置2. 克隆项目3. 安装依赖项4. 下载模型权重5. 理解end2end的含义6. 示例推理7. 文件介绍8. 把文件弄到树莓派4B执行9. 进一步尝试fp16的onnx&#xff08;行不通&#xff09;10. 视频流检测 这里有大概的环境配置&am…

80个在线小游戏源码

源码简介 搭建80个在线小游戏网站源码&#xff0c;解压即可食用&#xff0c;支持在本地浏览器打开。 安装教程 纯HTML&#xff0c;直接将压缩包上传网站目录解压即可 首页截图 源码下载 80个在线小游戏源码-小8源码屋

Mac虚拟机装Windows Mac环境安装Win虚拟机教程 macbookpro安装windows虚拟机

在如今多元的数字时代&#xff0c;我们经常需要在不同的操作系统环境下进行工作和学习。而对于Mac用户来说&#xff0c;有时候需要在自己的电脑上安装Windows操作系统&#xff0c;以体验更多软件及功能&#xff0c;而在Mac安装Windows虚拟机是常用的一种操作。下面就来看看Mac虚…

前端框架EXT.NET Dotnet 3.5开发的实验室信息管理系统(LIMS)成品源码 B/S架构

前端框架EXT.NET Dotnet 3.5开发的实验室信息管理系统&#xff08;LIMS&#xff09;成品源码 B/S架构 LIMS实验室管理系统 发展历史 实验室信息管理系统(LIMS)&#xff0c;就是指通过计算机网络技术对实验的各种信息进行管理的计算机软、硬件系统。也就是将计算机网络技术与现…

新手答疑 | 零基础该怎么学习嵌入式?嵌入式Linux学习路线是什么?嵌入式开发板推荐?

很多初学者想要涉足嵌入式Linux开发领域&#xff0c;但往往在刚入门阶段&#xff0c;会因为初次接触到大量复杂的概念术语和深奥的技术文档感到压力重重&#xff0c;面对这些内容不知从何下手&#xff0c;感到十分迷茫&#xff0c;网上的内容也纷繁复杂&#xff0c;没有清晰的学…
最新文章