标签:plt keras 人工神经网络 --- tf train test TensorFlow model
这种格式也可以,但不清晰
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
#加载数据
mnist = tf.keras.datasets.mnist
(train_x,train_y),(test_x,test_y) = mnist.load_data()
x_train ,x_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)
y_train,y_test = tf.cast (train_y,tf.int16),tf.cast(test_y,tf.int16)
#建立模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
model.summary()
#配置模型的训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
#训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5,validation_split=0.2)
#评估模型
model.evaluate(x_test,y_test,verbose=2 )
#使用模型
plt.axis('off')
plt.imshow(test_x[0],cmap='gray')
plt.show()
print(y_test[0])
model.predict([[x_test[0]]])
print(np.argmax(model.predict([[x_test[0]]])))
标签:plt,keras,人工神经网络,---,tf,train,test,TensorFlow,model 来源: https://blog.csdn.net/mabaiteng/article/details/121925811
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。