Recurrent Neural Networks and Their Wrappers in TensorFlow

2019-02-15



  • tf.nn.rnn_cell.LSTMCell

    • Related names: tf.nn.rnn_cell.BasicLSTMCell (a simplified variant), tf.contrib.rnn.LSTMCell (an alias)

    • See: tf.nn.rnn_cell.LSTMCell

    • Outputs:

      • output: the LSTM unit's output h. It differs from the LSTM cell state c in that c is additionally passed through a tanh activation and multiplied by the sigmoid output gate to produce h. shape: [batch_size, num_units]
      • new_state: the LSTM cell state and LSTM output at the current time step, packed as an LSTMStateTuple (c, h), where h is identical to the output tensor above. shape: ([batch_size, num_units], [batch_size, num_units])
    • Example:

      batch_size = 10
      embedding_dim = 300
      inputs = tf.Variable(tf.random_normal([batch_size, embedding_dim]))
      previous_state = (tf.Variable(tf.random_normal([batch_size, 128])),
                        tf.Variable(tf.random_normal([batch_size, 128])))
      lstmcell = tf.nn.rnn_cell.LSTMCell(128)
      outputs, (c_state, h_state) = lstmcell(inputs, previous_state)

      Output:

      (<tf.Tensor 'lstm_cell/mul_2:0' shape=(10, 128) dtype=float32>,
       LSTMStateTuple(c=<tf.Tensor 'lstm_cell/add_1:0' shape=(10, 128) dtype=float32>, h=<tf.Tensor 'lstm_cell/mul_2:0' shape=(10, 128) dtype=float32>))
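To make the output/state relationship concrete, here is a minimal NumPy sketch of a single LSTM step. This is not TF's implementation: the weights are random placeholders, the forget bias is omitted, and only the gate order [i, j, f, o] mimics BasicLSTMCell's kernel layout.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c_prev, h_prev, W, b):
    # Gate order [i, j, f, o] mimics TF's BasicLSTMCell kernel layout.
    z = np.concatenate([x, h_prev], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(j)  # new cell state
    h = sigmoid(o) * np.tanh(c)  # output: activated cell state times the output gate
    return h, (c, h)

rng = np.random.default_rng(0)
batch_size, embedding_dim, num_units = 10, 300, 128
x = rng.standard_normal((batch_size, embedding_dim))
c_prev = rng.standard_normal((batch_size, num_units))
h_prev = rng.standard_normal((batch_size, num_units))
W = rng.standard_normal((embedding_dim + num_units, 4 * num_units)) * 0.01
b = np.zeros(4 * num_units)

output, (c_state, h_state) = lstm_step(x, c_prev, h_prev, W, b)
assert output.shape == (10, 128)
assert np.array_equal(output, h_state)  # the h in the state is the output itself
```

The last assertion shows why the tensor names in the example above coincide: `output` and the `h` member of the returned state are the same value.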
  • tf.nn.rnn_cell.MultiRNNCell

    • See: tf.nn.rnn_cell.MultiRNNCell

    • Outputs:

      • outputs: the output of the topmost cell for the current time step. shape: [batch_size, cell.output_size]
      • states: the state of every layer; an M-layer LSTM yields a tuple of M LSTMStateTuples.
    • Example:

      batch_size = 10
      inputs = tf.Variable(tf.random_normal([batch_size, 128]))
      previous_state0 = (tf.random_normal([batch_size, 100]), tf.random_normal([batch_size, 100]))
      previous_state1 = (tf.random_normal([batch_size, 200]), tf.random_normal([batch_size, 200]))
      previous_state2 = (tf.random_normal([batch_size, 300]), tf.random_normal([batch_size, 300]))
      num_units = [100, 200, 300]
      cells = [tf.nn.rnn_cell.LSTMCell(num_unit) for num_unit in num_units]
      mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
      outputs, states = mul_cells(inputs, (previous_state0, previous_state1, previous_state2))

      Output:

      outputs:
      <tf.Tensor 'multi_rnn_cell_1/cell_2/lstm_cell/mul_2:0' shape=(10, 300) dtype=float32>
      states:
      (LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell_1/cell_0/lstm_cell/add_1:0' shape=(10, 100) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell_1/cell_0/lstm_cell/mul_2:0' shape=(10, 100) dtype=float32>),
       LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell_1/cell_1/lstm_cell/add_1:0' shape=(10, 200) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell_1/cell_1/lstm_cell/mul_2:0' shape=(10, 200) dtype=float32>),
       LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell_1/cell_2/lstm_cell/add_1:0' shape=(10, 300) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell_1/cell_2/lstm_cell/mul_2:0' shape=(10, 300) dtype=float32>))
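The stacking dataflow can be sketched in plain NumPy: each layer's output feeds the next layer's input, and one (c, h) state pair is produced per layer. This is a simplified single-step sketch with hypothetical random weights, not TF's implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c_prev, h_prev, W, b):
    # Minimal LSTM step; gate order [i, j, f, o] as in TF's BasicLSTMCell.
    z = np.concatenate([x, h_prev], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(j)
    h = sigmoid(o) * np.tanh(c)
    return h, (c, h)

rng = np.random.default_rng(0)
batch_size, input_dim = 10, 128
num_units = [100, 200, 300]

# One (W, b) per layer; layer l's input width is the previous layer's output width.
dims = [input_dim] + num_units
params = [(rng.standard_normal((dims[l] + u, 4 * u)) * 0.01, np.zeros(4 * u))
          for l, u in enumerate(num_units)]

x = rng.standard_normal((batch_size, input_dim))
prev_states = [(np.zeros((batch_size, u)), np.zeros((batch_size, u))) for u in num_units]

new_states, layer_input = [], x
for (W, b), (c_prev, h_prev) in zip(params, prev_states):
    h, state = lstm_step(layer_input, c_prev, h_prev, W, b)
    new_states.append(state)
    layer_input = h  # this layer's output is the next layer's input

outputs = layer_input  # MultiRNNCell returns the topmost layer's output
assert outputs.shape == (10, 300)
assert [c.shape for (c, _h) in new_states] == [(10, 100), (10, 200), (10, 300)]
```

This reproduces the shapes seen above: the call returns only the top layer's (10, 300) output, while the states keep one per-layer tuple each.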
  • tf.nn.dynamic_rnn

    • See: tf.nn.dynamic_rnn

    • Outputs:

      • outputs: the LSTM output at every time step; with multiple LSTM layers, this is the topmost layer's output at each time step. shape: [batch_size, max_time, cell.output_size]
      • state: the state at the last time step, returned as an LSTMStateTuple; with M LSTM layers, a tuple of M LSTMStateTuples is returned. Single layer: [batch_size, cell.output_size]; M layers: a tuple of M LSTMStateTuples. In other words, outputs[:, -1, :] == state[-1].h (the last output equals the h of the topmost layer's final state).
    • Example:

      batch_size = 10
      max_time = 20
      data = tf.Variable(tf.random_normal([batch_size, max_time, 128]))
      # create a BasicRNNCell
      rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
      
      # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
      
      # defining initial state
      initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
      
      # 'state' is a tensor of shape [batch_size, cell_state_size]
      outputs, state = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=data,
                                         initial_state=initial_state,
                                         dtype=tf.float32)

      Output:

      outputs:
      <tf.Tensor 'rnn_2/transpose_1:0' shape=(10, 20, 128) dtype=float32>
      state:
      <tf.Tensor 'rnn_2/while/Exit_3:0' shape=(10, 128) dtype=float32>
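What dynamic_rnn does for a single BasicRNNCell can be sketched as a plain NumPy loop over time: h_t = tanh([x_t, h_{t-1}] @ W + b), collecting each step's output. The weights here are random placeholders, not a trained model.

```python
import numpy as np

def rnn_step(x, h_prev, W, b):
    # BasicRNNCell step: h_t = tanh([x_t, h_{t-1}] @ W + b)
    return np.tanh(np.concatenate([x, h_prev], axis=1) @ W + b)

rng = np.random.default_rng(0)
batch_size, max_time, dim, num_units = 10, 20, 128, 128
data = rng.standard_normal((batch_size, max_time, dim))
W = rng.standard_normal((dim + num_units, num_units)) * 0.01
b = np.zeros(num_units)

h = np.zeros((batch_size, num_units))  # equivalent of zero_state
step_outputs = []
for t in range(max_time):
    h = rnn_step(data[:, t, :], h, W, b)
    step_outputs.append(h)
outputs = np.stack(step_outputs, axis=1)  # [batch_size, max_time, num_units]
state = h                                 # final state

assert outputs.shape == (10, 20, 128)
# For a single-layer RNN, the last step's output IS the final state:
assert np.array_equal(outputs[:, -1, :], state)
```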
      Example (multi-layer):

      batch_size = 10
      max_time = 20
      data = tf.Variable(tf.random_normal([batch_size, max_time, 128]))
      # create 2 LSTMCells
      rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
      
      # create a RNN cell composed sequentially of a number of RNNCells
      multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
      
      # 'outputs' is a tensor of shape [batch_size, max_time, 256]
      # 'state' is a N-tuple where N is the number of LSTMCells containing a
      # tf.contrib.rnn.LSTMStateTuple for each cell
      outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                         inputs=data,
                                         dtype=tf.float32)

      Output:

      outputs:
      <tf.Tensor 'rnn_1/transpose_1:0' shape=(10, 20, 256) dtype=float32>
      state:
      (LSTMStateTuple(c=<tf.Tensor 'rnn_1/while/Exit_3:0' shape=(10, 128) dtype=float32>, h=<tf.Tensor 'rnn_1/while/Exit_4:0' shape=(10, 128) dtype=float32>),
       LSTMStateTuple(c=<tf.Tensor 'rnn_1/while/Exit_5:0' shape=(10, 256) dtype=float32>, h=<tf.Tensor 'rnn_1/while/Exit_6:0' shape=(10, 256) dtype=float32>))
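The claim that outputs[:, -1, :] equals the topmost layer's final h can be checked with a NumPy unroll of a 2-layer LSTM stack (simplified step, random placeholder weights; not TF's implementation).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c_prev, h_prev, W, b):
    # Minimal LSTM step; gate order [i, j, f, o] as in TF's BasicLSTMCell.
    z = np.concatenate([x, h_prev], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(j)
    h = sigmoid(o) * np.tanh(c)
    return c, h

rng = np.random.default_rng(0)
batch_size, max_time, dim = 10, 20, 128
sizes = [128, 256]

dims = [dim] + sizes
Ws = [rng.standard_normal((dims[l] + u, 4 * u)) * 0.01 for l, u in enumerate(sizes)]
bs = [np.zeros(4 * u) for u in sizes]
cs = [np.zeros((batch_size, u)) for u in sizes]
hs = [np.zeros((batch_size, u)) for u in sizes]

data = rng.standard_normal((batch_size, max_time, dim))
step_outputs = []
for t in range(max_time):
    layer_in = data[:, t, :]
    for l in range(len(sizes)):
        cs[l], hs[l] = lstm_step(layer_in, cs[l], hs[l], Ws[l], bs[l])
        layer_in = hs[l]
    step_outputs.append(layer_in)     # top layer's output at step t
outputs = np.stack(step_outputs, axis=1)  # [batch, max_time, 256]
state = list(zip(cs, hs))                 # one (c, h) per layer, like the tuple above

assert outputs.shape == (10, 20, 256)
assert np.array_equal(outputs[:, -1, :], state[-1][1])  # last output == top layer's final h
```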
  • tf.nn.bidirectional_dynamic_rnn

    • See: tf.nn.bidirectional_dynamic_rnn

    • Outputs:

      • outputs: (output_fw, output_bw): the forward cell's and backward cell's outputs

        where output_fw and output_bw each have shape [batch_size, max_time, cell.output_size]

      • state: (output_state_fw, output_state_bw): a tuple of the forward and backward hidden states

        where output_state_fw and output_state_bw are each an LSTMStateTuple (c, h), i.e. the cell state and the hidden output
  • tf.contrib.seq2seq.dynamic_decode

    • Outputs:
      • final_outputs: contains rnn_output and sample_id, accessible as final_outputs.rnn_output and final_outputs.sample_id.
      • final_state: the attention alignments can be recovered from the final decoder state, e.g. alignments = tf.transpose(final_decoder_state.alignment_history.stack(), [1, 2, 0])
      • final_sequence_lengths
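The transpose in that alignment-extraction snippet can be illustrated with NumPy. Assuming alignment_history.stack() yields one attention distribution per decoder step, shaped [decoder_steps, batch_size, memory_time] (the data below is random stand-in values), the [1, 2, 0] permutation reorders it to [batch_size, memory_time, decoder_steps]:

```python
import numpy as np

rng = np.random.default_rng(0)
decoder_steps, batch_size, memory_time = 7, 4, 11

# Stand-in for final_decoder_state.alignment_history.stack():
# one attention distribution over the memory per decoder step.
stacked = rng.random((decoder_steps, batch_size, memory_time))

# tf.transpose(..., [1, 2, 0]) reorders the axes to
# [batch_size, memory_time, decoder_steps].
alignments = np.transpose(stacked, (1, 2, 0))
assert alignments.shape == (4, 11, 7)
# Element check: alignments[b, m, t] == stacked[t, b, m]
assert np.array_equal(alignments[0, :, 0], stacked[0, 0, :])
```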

Source: https://www.cnblogs.com/mengnan/p/10384484.html
