ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

多次在线培训后,为什么识别率下降?

2019-11-11 04:55:15  阅读:240  来源: 互联网

标签:mnist tensorflow machine-learning python


我正在使用tensorflow对MNIST数据集进行图像识别.在每个训练时期,我随机选择了10,000张图像并进行了批量训练,在线训练数量为1.在最初的几个时期中,识别率有所提高,但是在几个时期之后,识别率开始大大下降. (在前20个时期中,识别率高达〜94%.然后,识别率从90-> 50-> 40-> 30-> 20).这是什么原因呢?

另外,批大小为1时,性能比批大小为100时差(最大识别率94%对96%).我翻阅了几篇参考文献,但对于小批量还是大批量实现更好的性能似乎有矛盾的结果.在这种情况下会是什么情况?

编辑:我还添加了训练数据集和测试数据集的识别率图.Recognition rate vs. epoch

我已附上以下代码的副本.谢谢您的帮助!

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)

#parameters
n_nodes_hl1 = 500
n_nodes_hl2 = 500
n_nodes_hl3 = 500
n_classes = 10
batch_size = 1
x = tf.placeholder('float', [None, 784])
y = tf.placeholder('float')

#model of neural network
def neural_network_model(data):
    hidden_1_layer = {'weights':tf.Variable(tf.random_normal([784, n_nodes_hl1])               , name='l1_w'),
                      'biases': tf.Variable(tf.random_normal([n_nodes_hl1])                    , name='l1_b')}

    hidden_2_layer = {'weights':tf.Variable(tf.random_normal([n_nodes_hl1, n_nodes_hl2])       , name='l2_w'),
                      'biases' :tf.Variable(tf.random_normal([n_nodes_hl2])                    , name='l2_b')}

    hidden_3_layer = {'weights':tf.Variable(tf.random_normal([n_nodes_hl2, n_nodes_hl3])       , name='l3_w'),
                      'biases' :tf.Variable(tf.random_normal([n_nodes_hl3])                    , name='l3_b')}

    output_layer   = {'weights':tf.Variable(tf.random_normal([n_nodes_hl3, n_classes])     , name='lo_w'),
                      'biases' :tf.Variable(tf.random_normal([n_classes])                   , name='lo_b')}

    l1 = tf.add(tf.matmul(data,hidden_1_layer['weights']), hidden_1_layer['biases'])
    l1 = tf.nn.relu(l1) 
    l2 = tf.add(tf.matmul(l1,hidden_2_layer['weights']), hidden_2_layer['biases'])
    l2 = tf.nn.relu(l2)     
    l3 = tf.add(tf.matmul(l2,hidden_3_layer['weights']), hidden_3_layer['biases'])
    l3 = tf.nn.relu(l3)
    output = tf.matmul(l3,output_layer['weights']) + output_layer['biases']    
return output

#train neural network
def train_neural_network(x):
    prediction = neural_network_model(x)
    cost = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
    optimizer = tf.train.AdamOptimizer().minimize(cost)
    hm_epoches = 100
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(hm_epoches):
            epoch_loss=0
            for batch in range (10000):
                epoch_x, epoch_y=mnist.train.next_batch(batch_size)                
                _,c =sess.run([optimizer, cost], feed_dict = {x:epoch_x, y:epoch_y})
                epoch_loss += c
            correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y,1))
            accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
            print(epoch_loss)
            print('Accuracy_test:', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
            print('Accuracy_train:', accuracy.eval({x:mnist.train.images, y:mnist.train.labels}))

train_neural_network(x)

解决方法:

滴胶精度

你太合身了.这是当模型以重要特征为代价学习特定于训练数据中图像伪像的虚假特征时.任何应用程序的主要实验结果之一就是确定训练迭代的最佳次数.

例如,您训练数据中7的80%恰好在茎底部附近的右侧稍微偏斜,而4和1则没有.经过过多的训练之后,您的模型“决定”从其他数字中分辨出7的最佳方法是从这个额外的斜率开始,尽管有其他功能.结果,现在一些1和4被归类为7.

批量大小

同样,最佳批次大小是实验结果之一.通常,批大小为1太小:这会使前几个输入图像对内核或感知器训练中的早期权重产生太大影响.这是过度拟合的一个小案例:其中一项对模型有不适当的影响.但是,它足以将您的最佳结果更改2%.

您需要在批次大小与其他超参数之间取得平衡,以找到模型的“最佳位置”,最佳性能以及最短的培训时间.以我的经验,最好增加批处理的大小,直到每张图像的时间降级.我使用最多的模型(MNIST,CIFAR-10,AlexNet,GoogleNet,ResNet,VGG等)一旦达到最小批处理量,精度损失就很小.从那里开始,训练速度通常是选择最佳使用的可用RAM的批量大小的问题.

标签:mnist,tensorflow,machine-learning,python
来源: https://codeday.me/bug/20191111/2017458.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有