ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

LSTM预测sin(X)

2019-08-31 19:00:51  阅读:251  来源: 互联网

标签:预测 predictions ds test train tf np LSTM sin


1.模型

多层LSTM

2.用到的函数

tf.nn.rnn_cell.BasicLSTMCell(num_units)

num_units这个参数的大小就是LSTM输出结果的维度。例如num_units=128, 那么LSTM网络最后输出就是一个128维的向量。http://www.mtcnn.com/?p=529

tf.nn.dynamic_rnn

https://blog.csdn.net/junjun150013652/article/details/81331448

tf.contrib.layers.fully_connected(inputs,num_outputs)

增加一个全连接层
自动初始化w和b
激活函数默认为relu函数
输出个数由num_outputs指定

tf.losses.mean_squared_error

https://www.w3cschool.cn/tensorflow_python/tensorflow_python-zkxr2x87.html

3.代码

import numpy as np
import tensorflow as tf
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")

hidden_size = 30
num_layers = 2

timesteps = 10
training_steps = 1000
batch_size = 32

training_examples = 10000
testing_examples  =1000
sample_gap = 0.01

def generate_data(seq) :
    x = []
    y = []
    for i in range(len(seq) - timesteps) :
        x.append([seq[i: i + timesteps]])
        y.append([seq[i + timesteps]])
        
    return np.array(x, dtype=np.float32), np.array(y, dtype=np.float32)

def lstm_model(x, y, is_training) :
    cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
            for _ in range(num_layers)])
    
    outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    
    # 下标为-1表示取出列表的最后一行数据值
    output = outputs[:, -1, :]
    
    predictions = tf.contrib.layers.fully_connected(
                    output, 1, activation_fn=None)
    
    if not is_training :
        return predictions, None, None
    
    loss = tf.losses.mean_squared_error(labels=y, predictions=predictions)
    
    trian_op = tf.contrib.layers.optimize_loss(
                loss, tf.train.get_global_step(),
                optimizer='Adagrad', learning_rate=0.1)
    
    return predictions, loss, trian_op

def train(sess, train_x, train_y) :
    ds = tf.data.Dataset.from_tensor_slices((train_x, train_y))
    ds = ds.repeat().shuffle(1000).batch(batch_size)
    x, y = ds.make_one_shot_iterator().get_next()
    
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE) :
        predictions, loss, train_op = lstm_model(x, y, True)
        
    sess.run(tf.global_variables_initializer())
    
    for i in range (training_steps) :
        _, l = sess.run([train_op, loss])
        if i % 100 == 0 :
            print("train step: " + str(i) + ", loss: " + str(l))
            
def run_eval(sess, test_x, test_y) :
    ds = tf.data.Dataset.from_tensor_slices((test_x, test_y))
    ds = ds.batch(1)
    x, y = ds.make_one_shot_iterator().get_next()
    
    with tf.variable_scope("model", reuse=True) :
        predection, _, _ = lstm_model(x, [0,0], False)
        
    predictions = []
    labels = []
    for i in range(testing_examples) :
        p, l = sess.run([predection, y])
        predictions.append(p)
        labels.append(l)
        
    predictions = np.array(predictions).squeeze()
    labels = np.array(labels).squeeze()
    rmse = np.sqrt(((predictions - labels) ** 2).mean(axis=0))
    print("Mean square error is: %f" % rmse)
    
    %matplotlib inline
    plt.figure()
    plt.plot(predictions, label='predictions')
    plt.plot(labels,  label='real_sin')
    plt.legend()
    plt.show()
    
test_start = (training_examples + timesteps) * sample_gap
test_end = test_start + (testing_examples + timesteps) * sample_gap
test_x, test_y = generate_data(np.sin(np.linspace(
                    0, test_start, training_examples+timesteps, dtype=np.float32)))

with tf.Session() as sess :
    train(sess, train_x, train_y)
    run_eval(sess, test_x, test_y)
    
  

 

 

 

 

 

 

 

标签:预测,predictions,ds,test,train,tf,np,LSTM,sin
来源: https://blog.csdn.net/WukongAKK/article/details/100176030

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有