ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

seq2seq

2020-01-14 14:07:41  阅读:359  来源: 互联网

标签:inputs RNN seq2seq Encoder decoder tf encoder


算法框架

seq2seq简单来说就一个编码,再解码的过程。seq2seq 模型就像一个翻译模型,输入是一个序列(比如一个英文句子),输出也是一个序列(比如该英文句子所对应的法文翻译)。这种结构最重要的地方在于输入序列和输出序列的长度是可变的。
seq2seq原理图
在这里插入图片描述

算法原理:

1、由编码和解吗两个部分组成seq2seq模型的整体框架
2、编码阶段的RNN序列的最后一个状态作为解吗RNN的初始状态
3、解码阶段,从符号开始,每个时刻的输出将会作为下一个时刻的输入,以此类推,直到 DecoderCell 某个时刻预测输出特殊符号 结束。

原理解析:

文章[1]中提出 Encoder-Decoder 这种结构。
其中 Encoder 部分应该是非常容易理解的,就是一个RNNCell(RNN ,GRU,LSTM 等) 结构。每个 timestep, 我们向 Encoder 中输入一个字/词(一般是表示这个字/词的一个实数向量),直到我们输入这个句子的最后一个字/词 XTX_TXT​,然后输出整个句子的语义向量 c(一般情况下, c=h(XT)c=h_(X_T)c=h(​XT​) ,XTX_TXT​ 是最后一个输入)。因为 RNN 的特点就是把前面每一步的输入信息都考虑进来了,所以理论上这个 cc 就能够把整个句子的信息都包含了,我们可以把 cc 当成这个句子的一个语义表示,也就是一个句向量。在 Decoder 中,我们根据 Encoder 得到的句向量 cc, 一步一步地把蕴含在其中的信息分析出来。
在这里插入图片描述
代码解析:
encoder阶段代码

#1.首先定义编码的输入
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
#2.定义embed矩阵
embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0), dtype=tf.float32)
# 3.对输入进行embed
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
# 4.定义RNN结构
encoder_cell = tf.nn.rnn_cell.LSTMCell(encoder_hidden_units)
# 5.执行RNN,得到输出和状态,最终状态作为解码RNN的初始状态
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(encoder_cell, encoder_inputs_embedded,dtype=tf.float32, time_major=True)

decoder阶段代码

# 1.首先定义解码的输入和标签,
decoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
# 2.编码解码使用同一embed矩阵,对输入进行decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)
# 3.定义解码RNN结构,可以和编码不同
decoder_cell = tf.nn.rnn_cell.LSTMCell(decoder_hidden_units)
#4.执行RNN,得到输出和状态
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs_embedded,
    initial_state=encoder_final_state, 
    dtype=tf.float32, time_major=True, scope="plain_decoder")
#5.定义全连接层,得到softmax
decoder_logits = tf.layers.dense(decoder_outputs, vocab_size)
decoder_prediction = tf.argmax(decoder_logits, 2) 

训练过程 优化:loss & optimizer

stepwise_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=decoder_targets, logits=decoder_logits)
loss = tf.reduce_mean(stepwise_cross_entropy)
# 梯度裁剪        
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
gvs = optimizer.compute_gradients(loss)
capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs]
train_op = optimizer.apply_gradients(capped_gvs)

注意力机制

论文[3] Bahdanau et al. 是在 Encoder 和 Decoder 的基础上提出了注意力机制。
在论文[1] 的 decoder 中,每次预测下一个词都会用到中间语义 cc,而这个 cc 呢,主要就是最后一个时刻的隐藏状态。在论文[3] 中,我们可以看到在Decoder进行预测的时候,Encoder 中每个时刻的隐藏状态都被利用上了。这样子,Encoder 就能利用多个语义信息(隐藏状态)来表达整个句子的信息了。

加attention原因

如果拿机器翻译来解释这个分心模型的Encoder-Decoder框架更好理解,比如输入的是英文句子:Tom chase Jerry,Encoder-Decoder框架逐步生成中文单词:“汤姆”,“追逐”,“杰瑞”。在翻译“杰瑞”这个中文单词的时候,分心模型里面的每个英文单词对于翻译目标单词“杰瑞”贡献是相同的,很明显这里不太合理,显然“Jerry”对于翻译成“杰瑞”更重要,但是分心模型是无法体现这一点的,这就是为何说它没有引入注意力的原因。
引入attention model之后:每个英文单词的概率代表了翻译当前单词“杰瑞”时,注意力分配模型分配给不同英文单词的注意力大小。这对于正确翻译目标语单词肯定是有帮助的,因为引入了新的信息。
这意味着在生成每个单词Yi的时候,原先都是相同的中间语义表示C会替换成根据当前生成单词而不断变化的Ci。理解AM模型的关键就是这里,即由固定的中间语义表示C换成了根据当前输出单词来调整成加入注意力模型的变化的Ci。增加了AM模型的Encoder-Decoder框架理解起来如图所示。
在这里插入图片描述

参靠文献

[1] Cho et al., 2014 . Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
[2]https://blog.csdn.net/Jerr__y/article/details/53749693
[3] Bahdanau et al., 2014. Neural Machine Translation by Jointly Learning to Align and Translate
https://blog.csdn.net/malefactor/article/details/50550211
[4] Jean et. al., 2014. On Using Very Large Target Vocabulary for Neural Machine Translation

地平线的光 发布了11 篇原创文章 · 获赞 1 · 访问量 475 私信 关注

标签:inputs,RNN,seq2seq,Encoder,decoder,tf,encoder
来源: https://blog.csdn.net/u010443559/article/details/103968480

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有