ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

03 Transformer 中的多头注意力(Multi-Head Attention)Pytorch代码实现

2022-07-27 20:03:12  阅读:201  来源: 互联网

标签:03 Transformer self mask Head value query model head


3:20 来个赞

24:43 弹幕,是否懂了

img

QKV 相乘(QKV 同源),QK 相乘得到相似度A,AV 相乘得到注意力值 Z

  1. 第一步实现一个自注意力机制

自注意力计算

def self_attention(query, key, value, dropout=None, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # mask的操作在QK之后,softmax之前
    if mask is not None:
        mask.cuda()
        scores = scores.masked_fill(mask == 0, -1e9)
    self_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        self_attn = dropout(self_attn)
    return torch.matmul(self_attn, value), self_attn

多头注意力

# PYthon/PYtorch/你看的这个模型的理论
class MultiHeadAttention(nn.Module):

    def __init__(self):
        super(MultiHeadAttention, self).__init__()


    def forward(self,  head, d_model, query, key, value, dropout=0.1,mask=None):
        """

        :param head: 头数,默认 8
        :param d_model: 输入的维度 512
        :param query: Q
        :param key: K
        :param value: V
        :param dropout:
        :param mask:
        :return:
        """
        assert (d_model % head == 0)
        self.d_k = d_model // head
        self.head = head
        self.d_model = d_model

        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)

        # 自注意力机制的 QKV 同源,线性变换

        self.linear_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.attn = None



        # if mask is not None:
        #     # 多头注意力机制的线性变换层是4维,是把query[batch, frame_num, d_model]变成[batch, -1, head, d_k]
        #     # 再1,2维交换变成[batch, head, -1, d_k], 所以mask要在第一维添加一维,与后面的self attention计算维度一样
        #     mask = mask.unsqueeze(1)

        n_batch = query.size(0)

        # 多头需要对这个 X 切分成多头

        # query==key==value
        # [b,1,512]
        # [b,8,1,64]

        # [b,32,512]
        # [b,8,32,64]

        query = self.linear_query(query).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64]
        key = self.linear_key(key).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64]
        value = self.linear_value(value).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64]

        x, self.attn = self_attention(query, key, value, dropout=self.dropout, mask=mask)
        # [b,8,32,64]
        # [b,32,512]
        # 变为三维, 或者说是concat head
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.head * self.d_k)

        return self.linear_out(x)

标签:03,Transformer,self,mask,Head,value,query,model,head
来源: https://www.cnblogs.com/nickchen121/p/16526123.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有