ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

task04:卷积情感分析

2021-09-24 20:06:35  阅读:231  来源: 互联网

标签:dim 卷积 self batch len filter 情感 task04 size


task04:卷积情感分析

  • CNN:
    • 能够从局部输入图像块中提取特征,并能将表示模块化,同时可以高效第利用数据
    • 可以用于处理时序数据,时间可以被看作一个空间维度,就像二维图像的高度和宽度
  • 那么为什么要在文本上使用卷积神经网络呢?
    • 与3x3 filter可以查看图像块的方式相同,1x2 filter 可以查看一段文本中的两个连续单词,即双字符
    • 本模型将使用多个不同大小的filter,这些filter将查看文本中的bi-grams(a 1x2 filter)、tri-grams(a 1x3 filter)and/or n-grams(a 1x n n n filter)。
    • 与使用FastText模型的方法不同,本节不再需要刻意地创建bi-gram将它们附加到句子末尾。

一、数据预处理:

  • 构建vocab并加载预训练好的此嵌入:

    MAX_VOCAB_SIZE = 25_000
    
    TEXT.build_vocab(train_data, 
                     max_size = MAX_VOCAB_SIZE, 
                     vectors = "glove.6B.100d", 
                     unk_init = torch.Tensor.normal_)
    
    LABEL.build_vocab(train_data)
    
  • 创建迭代器:

    BATCH_SIZE = 64
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, valid_data, test_data), 
        batch_size = BATCH_SIZE, 
        device = device)
    

二、构建模型:

  • 将一段文本中的每个单词沿着一个轴展开,向量中的元素沿着另一个维度展开。

  • 可以使用一个 [n x emb_dim] 的filter。可以完全覆盖 n n n 个words,因为它们的宽度为emb_dim 尺寸。

  • 一般情况下,filter 的宽度等于"image" 的宽度,我们得到的输出是一个向量,其元素数等于图像的高度(或词的长度)减去 filter 的高度加上一。

  • 实现:

    • 我们借助 nn.Conv2d实现卷积层
    • 之后,我们通过卷积层和池化层传递张量,在卷积层之后使用’ReLU’激活函数。池化层的另一个很好的特性是它们可以处理不同长度的句子。而卷积层的输出大小取决于输入的大小,不同的批次包含不同长度的句子。如果没有最大池层,线性层的输入将取决于输入语句的长度,为了避免这种情况,我们将所有句子修剪/填充到相同的长度,但是线性层来说,线性层的输入一直都是filter的总数。
    • 如果句子的长度小于实验设置的最大filter,那么必须将句子填充到最大filter的长度。在IMDb数据中不会存在这种情况,所以我们不必担心。
    • 最后,我们对合并之后的filter输出执行dropout操作,然后将它们通过线性层进行预测。
    class CNN(nn.Module):
        def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, 
                     dropout, pad_idx):
            
            super().__init__()
                    
            self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
            
            self.convs = nn.ModuleList([
                                        nn.Conv2d(in_channels = 1, 
                                                  out_channels = n_filters, 
                                                  kernel_size = (fs, embedding_dim)) 
                                        for fs in filter_sizes
                                        ])
            
            self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
            
            self.dropout = nn.Dropout(dropout)
            
        def forward(self, text):
                    
            #text = [batch size, sent len]
            
            embedded = self.embedding(text)
                    
            #embedded = [batch size, sent len, emb dim]
            
            embedded = embedded.unsqueeze(1)
            
            #embedded = [batch size, 1, sent len, emb dim]
            
            conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
                
            #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
                    
            pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
            
            #pooled_n = [batch size, n_filters]
            
            cat = self.dropout(torch.cat(pooled, dim = 1))
    
            #cat = [batch size, n_filters * len(filter_sizes)]
                
            return self.fc(cat)
    

三、训练模型:

四、验证模型:

import spacy
nlp = spacy.load('en_core_web_sm')

def predict_sentiment(model, sentence, min_len = 5):
    model.eval()
    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
    if len(tokenized) < min_len:
        tokenized += ['<pad>'] * (min_len - len(tokenized))
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]
    tensor = torch.LongTensor(indexed).to(device)
    tensor = tensor.unsqueeze(0)
    prediction = torch.sigmoid(model(tensor))
    return prediction.item()

参考资料:

DataWhale开源资料

标签:dim,卷积,self,batch,len,filter,情感,task04,size
来源: https://blog.csdn.net/maozixiang/article/details/120462087

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有