ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Transformer 中的 attention

2022-05-08 12:03:33  阅读:232  来源: 互联网

标签:dim Transformer nn 特征 self attention out


Transformer 中的 attention

转自Transformer中的attention,看完不懂扇我脸

大火的transformer 本质就是:

*使用attention机制的seq2seq。*

所以它的核心就是attention机制,今天就讲attention。直奔代码VIT-pytorch:

https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

中的

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

看吧!就是这么简单。今天就彻底搞懂这个东西。

先记住attention的这么几个点:

  • attention和CNN、RNN、FC、GCN等都是一个级别的东西,用来提取特征;既然是特征提取,一定有权重(W+B)存在。
  • attention的优点:可以像CNN一样并行运算 + 像RNN一样通过一层就拥有全局资讯。有一个东西也可以做到,那就是FC,但是FC有个弱点:对输入尺寸有限制,说白了不好适应可变输入数据,这对于序列无疑是非常不友好的。
  • pooling也可以实现,但是它是无参的过程。例如点云数据,就可以用pooling来处理,当然也有一些网络是pooling is all your need。
  • 可以像CNN一样并行运算 ,其实CNN运算也是通过im2col或winograd等转化为矩阵运算的。
  • RNN不能并行,所以通常它处理的数据有“时序”这个特点。既然是“时序”,那么就不是同一个时刻完成的,所以不能并行化。

综上所述: attention优点 = CNN并行+RNN全局资讯+对输入尺寸(时序长度维度上)没有限制。

如果你能创造一个拥有上面三点优点的东西出来,你也可以引领潮流。

然后回到代码,再熟悉这么几个设置:

  • batch维度:大家利用同样的权重和操作提取特征,可以理解为for循环式,相互之间没有信息交互;
  • multi head维度:同batch类似,不过是利用的不同权重和相同操作提取特征,最后concate一起使用;
  • FC层:是作用在每一个特征上,类似CNN中的1X1,可以叫“pointwise”,和序列长度没有关系;因为序列中所有的特征经过的是同一个FC。

下面看这个图,看完不懂的可以扇自己了:

attention的顺序是:

  1. 你有长度为n(序列)的序列,每个元素都是一个特征,每个特征都是一个向量;
  2. 每个向量都经过FC1,FC2,FC3获取到q,k,v三个向量(长度自己定),记住,不同特征用的是同一个FC1,FC2,FC3。可以说对于一个head,就一组FC1,FC2,FC3。
  3. 特征1的q1和所有特征的k 进行点乘,获取一串值,注意:和自己的k也进行点乘;点乘向量变标量,表示相似性。多个K可不就是一串标量。
  4. 3中的那一串值进行softmax操作,作为权重 对所有v加权求和,获得特征1输出;
  5. 其他所有的特征和特征1的操作一样,注意所有特征是一块并行计算的;
  6. 最后获取的和输入一样长度的特征序列再经过FC进行长度(特征维度)调整,也可以不要;

对了,softmax之前不要忘记 除以 qkv长度开方进行scaled,其实就是标准化操作(我觉得可以理解为各种N(BN,GN,LN等))。

就是这么简单,你学会了吗?

标签:dim,Transformer,nn,特征,self,attention,out
来源: https://www.cnblogs.com/lwp-nicol/p/16245173.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有