
Attention Is All You Need




https://arxiv.org/abs/1706.03762

--------------------------------------------------------

2021-06-03


encoder-decoder

attention: for the output y at a given time step, its attention over the different parts of the input x (which can be understood as weights)

  self-attention: the output sequence is the input sequence itself, i.e. Q, K and V all come from the same sequence

  scaled dot-product attention: weight (select) V according to how similar Q is to K


  Dividing by a scaling factor: when d_k is large, the dot products grow large in magnitude, which pushes the softmax into regions where its gradient is very small
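
  In formula form (from the paper): Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where d_k is the dimension of the keys.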


import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()
        # fixed (non-trainable) table of sinusoidal position encodings, shape (max_len, d_model)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False

        # positions 0 .. max_len-1, shape (max_len, 1)
        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)

        # even dimension indices 0, 2, 4, ...
        _2i = torch.arange(0, d_model, step=2, device=device).float()

        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))

    def forward(self, x):
        # x: (batch, seq_len) token ids; return the first seq_len rows of the table
        batch, seq_len = x.size()
        return self.encoding[:seq_len, :]
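
A quick usage sketch (vocab_size, d_model, max_len and the embedding layer below are illustrative, not from the original post): the positional encoding is added to the token embeddings before the first encoder layer.

vocab_size, d_model, max_len = 1000, 512, 100
device = torch.device('cpu')

tok_emb = nn.Embedding(vocab_size, d_model)
pos_enc = PositionalEncoding(d_model, max_len, device)

x = torch.randint(1, vocab_size, (2, 10))   # (batch=2, seq_len=10) token ids
h = tok_emb(x) + pos_enc(x)                 # the (10, 512) table broadcasts over the batch
print(h.shape)                              # torch.Size([2, 10, 512])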


def pad_mask(seq_q, seq_k):
    # seq_q: (batch, len_q), seq_k: (batch, len_k) token ids, with 0 as the padding id
    len_q = seq_q.size(1)
    mask = seq_k.eq(0)                              # (batch, len_k), True where the key is padding
    mask = mask.unsqueeze(1).expand(-1, len_q, -1)  # (batch, len_q, len_k)
    return mask


def sequence_mask(seq):
    # upper-triangular mask that hides future positions in decoder self-attention
    batch, seq_len = seq.size()
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    mask = mask.unsqueeze(0).expand(batch, -1, -1)  # (batch, seq_len, seq_len), True above the diagonal
    return mask
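
A small sketch of how the two masks might be combined for decoder self-attention (the token ids below are made up; 0 is assumed to be the padding id):

tgt = torch.tensor([[5, 7, 2, 0],
                    [3, 9, 0, 0]])

pad = pad_mask(tgt, tgt)            # (2, 4, 4), True where the key position is padding
future = sequence_mask(tgt)         # (2, 4, 4), True above the diagonal (future positions)
dec_self_attn_mask = pad | future   # True = position must not be attended to
print(dec_self_attn_mask[0].int())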


class ScaledDotProductAttention(nn.Module):
    def __init__(self, attention_dropout=0.):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=64 ** -0.5, attn_mask=None):
        # q, k, v: (batch, seq_len, dim); the default scale assumes d_k = 64
        attention = torch.bmm(q, k.transpose(1, 2))   # (batch, len_q, len_k)
        if scale:
            attention = attention * scale
        if attn_mask is not None:
            # masked positions get -inf so their softmax weight becomes 0
            attention = attention.masked_fill_(attn_mask, float('-inf'))
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        context = torch.bmm(attention, v)
        return context, attention
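
A minimal sanity check with random tensors (shapes are illustrative); with dropout 0, every row of the returned attention matrix sums to 1:

attn = ScaledDotProductAttention(attention_dropout=0.)
q = torch.rand(2, 5, 64)                 # (batch, len_q, d_k)
k = torch.rand(2, 5, 64)
v = torch.rand(2, 5, 64)
context, weights = attn(q, k, v, scale=64 ** -0.5)
print(context.shape, weights.shape)      # (2, 5, 64) and (2, 5, 5)
print(weights.sum(dim=-1))               # all ones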


class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim=512, num_heads=8, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.dim_per_head = model_dim // num_heads

        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        residual = query

        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch = key.size(0)

        # linear projections: (batch, seq_len, model_dim)
        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split heads: (batch, seq_len, model_dim) -> (batch * num_heads, seq_len, dim_per_head)
        key = (key.view(batch, -1, num_heads, dim_per_head)
                  .transpose(1, 2)
                  .reshape(batch * num_heads, -1, dim_per_head))
        value = (value.view(batch, -1, num_heads, dim_per_head)
                      .transpose(1, 2)
                      .reshape(batch * num_heads, -1, dim_per_head))
        query = (query.view(batch, -1, num_heads, dim_per_head)
                      .transpose(1, 2)
                      .reshape(batch * num_heads, -1, dim_per_head))

        if attn_mask is not None:
            # one copy of the mask per head, in the same batch-major order as the reshape above
            attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)

        # scale by 1 / sqrt(d_k), with d_k = dim_per_head
        scale = dim_per_head ** -0.5

        context, attention = self.dot_product_attention(query, key, value, scale, attn_mask)

        # merge heads back: (batch * num_heads, seq_len, dim_per_head) -> (batch, seq_len, model_dim)
        context = (context.view(batch, num_heads, -1, dim_per_head)
                          .transpose(1, 2)
                          .reshape(batch, -1, dim_per_head * num_heads))

        # final linear projection, dropout, residual connection and layer norm
        output = self.linear_final(context)
        output = self.dropout(output)
        output = self.layer_norm(residual + output)

        return output, attention
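
Putting the pieces together for encoder self-attention on a toy batch (dimensions and token ids are illustrative; the random tensor stands in for token embedding + positional encoding):

mha = MultiHeadAttention(model_dim=512, num_heads=8, dropout=0.)
src = torch.randint(1, 1000, (2, 10))    # (batch, seq_len) token ids, no padding here
h = torch.rand(2, 10, 512)               # stand-in for embedding + positional encoding
mask = pad_mask(src, src)                # (2, 10, 10), all False since no id is 0
out, attn = mha(h, h, h, attn_mask=mask)
print(out.shape, attn.shape)             # torch.Size([2, 10, 512]) torch.Size([16, 10, 10])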

 
