ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

MLP Attention实现

2021-05-26 12:59:52  阅读:507  来源: 互联网

标签:src dim 实现 Attention batch len lengths MLP size


MLP Attention注意力机制的实现公式为:

s(x,q)=V\, ^{T}tanh(Wx+Uq)

参考

https://github.com/pytorch/translate/blob/master/pytorch_translate/attention/mlp_attention.py

https://www.aclweb.org/anthology/N16-1174.pdf

基于PyTorch框架实现加性注意力机制

from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def create_src_lengths_mask(
    batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None
):
    """
    Generate boolean mask to prevent attention beyond the end of source
    Inputs:
      batch_size : int
      src_lengths : [batch_size] of sentence lengths
      max_src_len: Optionally override max_src_len for the mask
    Outputs:
      [batch_size, max_src_len]
    """
    if max_src_len is None:
        max_src_len = int(src_lengths.max())
    src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)
    # returns [batch_size, max_seq_len]
    return (src_indices < src_lengths).int().detach()


def masked_softmax(scores, src_lengths, src_length_masking=True):
    """Apply source length masking then softmax.
    Input and output have shape bsz x src_len"""
    if src_length_masking:
        bsz, max_src_len = scores.size()
        # compute masks
        src_mask = create_src_lengths_mask(bsz, src_lengths)
        # Fill pad positions with -inf
        scores = scores.masked_fill(src_mask == 0, -np.inf)

    # Cast to float and then back again to prevent loss explosion under fp16.
    return F.softmax(scores.float(), dim=-1).type_as(scores)

# s(x, q) = v.T * tanh (W * x + b)
class MLPAttentionNetwork(nn.Module):

    def __init__(self, hidden_dim, attention_dim, src_length_masking=True):
        super(MLPAttentionNetwork, self).__init__()

        self.hidden_dim = hidden_dim
        self.attention_dim = attention_dim
        self.src_length_masking = src_length_masking

        # W * x + b
        self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
        # v.T
        self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, x, x_lengths):
        """
        :param x: seq_len * batch_size * hidden_dim
        :param x_lengths: batch_size
        :return: batch_size * seq_len, batch_size * hidden_dim
        """
        seq_len, batch_size, _ = x.size()
        # (seq_len * batch_size, hidden_dim)
        flat_inputs = x.view(-1, self.hidden_dim)
        # (seq_len * batch_size, attention_dim)
        mlp_x = self.proj_w(flat_inputs)
        # (batch_size, seq_len)
        att_scores = self.proj_v(mlp_x).view(seq_len, batch_size).t()
        # (seq_len, batch_size)
        normalized_masked_att_scores = masked_softmax(
            att_scores, x_lengths, self.src_length_masking
        ).t()
        # (batch_size, hidden_dim)
        attn_x = (x * normalized_masked_att_scores.unsqueeze(2)).sum(0)

        return normalized_masked_att_scores.t(), attn_x

测试代码为:

mlp = MLPAttentionNetwork(6, 4)
x = torch.rand((5, 3, 6))
x_lengths = torch.LongTensor([2, 3, 5])
att_scores, attn_x = mlp(x, x_lengths)
print(att_scores)
print(attn_x)

结果如下:

tensor([[0.5339, 0.4661, 0.0000, 0.0000, 0.0000],
        [0.3135, 0.3563, 0.3302, 0.0000, 0.0000],
        [0.2262, 0.1722, 0.2031, 0.2015, 0.1971]], grad_fn=<TBackward>)
tensor([[0.5803, 0.4982, 0.1476, 0.5926, 0.7372, 0.5238],
        [0.3763, 0.4945, 0.3840, 0.5774, 0.7962, 0.6052],
        [0.3320, 0.5883, 0.6167, 0.5233, 0.5037, 0.5494]],
       grad_fn=<SumBackward1>)

 

标签:src,dim,实现,Attention,batch,len,lengths,MLP,size
来源: https://blog.csdn.net/tszupup/article/details/117287405

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有