ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

用于Transformer的6种注意力的数学原理和代码实现

2022-03-31 10:34:08  阅读:274  来源: 互联网

标签:Transformer scale self attention mask Attention 数学原理 注意力


Transformer 的出色表现让注意力机制出现在深度学习的各处。本文整理了深度学习中最常用的6种注意力机制的数学原理和代码实现。

1、Full Attention

2017的《Attention is All You Need》中的编码器-解码器结构实现中提出。它结构并不复杂,所以不难理解。

上图 1.左侧显示了 Scaled Dot-Product Attention 的机制。当我们有多个注意力时,我们称之为多头注意力(右),这也是最常见的注意力的形式公式如下:

公式1

这里Q(Query)、K(Key)和V(values)被认为是它的输入,dₖ(输入维度)被用来降低复杂度和计算成本。这个公式可以说是深度学习中注意力机制发展的开端。下面我们看一下它的代码:

  1. class FullAttention(nn.Module):
  2. def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
  3. super(FullAttention, self).__init__()
  4. self.scale = scale
  5. self.mask_flag = mask_flag
  6. self.output_attention = output_attention
  7. self.dropout = nn.Dropout(attention_dropout)
  8. def forward(self, queries, keys, values, attn_mask):
  9. B, L, H, E = queries.shape
  10. _, S, _, D = values.shape
  11. scale = self.scale or 1. / sqrt(E)
  12. scores = torch.einsum("blhe,bshe->bhls", queries, keys)
  13. if self.mask_flag:
  14. if attn_mask is None:
  15. attn_mask = TriangularCausalMask(B, L, device=queries.device)
  16. scores.masked_fill_(attn_mask.mask, -np.inf)
  17. A = self.dropout(torch.softmax(scale * scores, dim=-1))
  18. V = torch.einsum("bhls,bshd->blhd", A, values)
  19. if self.output_attention:
  20. return (V.contiguous(), A)
  21. else:
  22. return (V.contiguous(), None)

2、ProbSparse Attention

借助“Transformer Dissection: A Unified Understanding of Transformer's Attention via the lens of Kernel”中的信息我们可以将公式修改为下面的公式2。第i个query的attention就被定义为一个概率形式的核平滑方法(kernel smoother):

公式2

从公式 2,我们可以定义第 i 个查询的稀疏度测量如下:

完整文章:

https://www.overfit.cn/post/739299d8be4e4ddc8f5804b37c6c82ad

标签:Transformer,scale,self,attention,mask,Attention,数学原理,注意力
来源: https://www.cnblogs.com/deephub/p/16080521.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有