ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Tensorflow2.0:分类问题之手写数字识别(mnist数据集)

2020-03-06 18:01:59  阅读:730  来源: 互联网

标签:loss network 10 28 meter Tensorflow2.0 tf 手写 mnist


Tensorflow2.0下载与环境配置请参考:TF2.0环境配置
程序清单

import tensorflow as tf
from tensorflow.keras import layers, optimizers, datasets, Sequential, metrics  # 导入TF子库

# 1.数据集准备
(x, y), (x_val, y_val) = datasets.mnist.load_data()  # 加载数据集,返回的是两个元组,分别表示训练集和测试集
x = tf.convert_to_tensor(x, dtype=tf.float32)/255.  # 转换为张量,并缩放到0~1
y = tf.convert_to_tensor(y, dtype=tf.int32)  # 转换为张量(标签)
print(x.shape, y.shape)

train_dataset = tf.data.Dataset.from_tensor_slices((x, y))  # 构建数据集对象
train_dataset = train_dataset.batch(32).repeat(10)  # 设置批量训练的batch为32,要将训练集重复训练10遍

# 2.网络搭建
network = Sequential([
    layers.Dense(256, activation='relu'),  # 第一层
    layers.Dense(128, activation='relu'),  # 第二层
    layers.Dense(10)  # 输出层
])
network.build(input_shape=(None, 28*28))  # 输入
# network.summary()

# 3.模型训练(计算梯度,迭代更新网络参数)
optimizer = optimizers.SGD(lr=0.01)  # 声明采用批量随机梯度下降方法,学习率=0.01
acc_meter = metrics.Accuracy()
for step, (x, y) in enumerate(train_dataset):  # 一次输入batch组数据进行训练
    with tf.GradientTape() as tape:  # 构建梯度记录环境
        x = tf.reshape(x, (-1, 28*28))  # 将输入拉直,[b,28,28]->[b,784]
        out = network(x)  # 输出[b, 10]
        y_onehot = tf.one_hot(y, depth=10)  # one-hot编码
        loss = tf.square(out - y_onehot)
        loss = tf.reduce_sum(loss)/32  # 定义均方差损失函数,注意此处的32对应为batch的大小
        grads = tape.gradient(loss, network.trainable_variables)  # 计算网络中各个参数的梯度
        optimizer.apply_gradients(zip(grads, network.trainable_variables))  # 更新网络参数
        acc_meter.update_state(tf.argmax(out, axis=1), y)  # 比较预测值与标签,并计算精确度
    if step % 200 == 0:  # 每200个step,打印一次结果
        print('Step', step, ': Loss is: ', float(loss), ' Accuracy: ', acc_meter.result().numpy())
        acc_meter.reset_states()

训练结果
在这里插入图片描述
这里要注意的是:将mnist数据集中的标签y转换成one-hot编码

标签:loss,network,10,28,meter,Tensorflow2.0,tf,手写,mnist
来源: https://blog.csdn.net/wjinjie/article/details/104700834

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有