ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Autoencoder 基于tensorflow2.0的代码

2021-03-23 14:02:01  阅读:239  来源: 互联网

标签:code Autoencoder name outputs 代码 28 tensorflow2.0 train Dense


具体原理不讲了,网上资料相当多,但是感觉直接可以用的代码不多,所以基于各种资料实现了代码。

第一,利用autoencoder降噪,

第一部分:数据准备 

from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
x_train = x_train.reshape((-1, 28 * 28)) / 255.0  # 除255的目的就是归一化 reshape的目的就是flatten()
x_test = x_test.reshape((-1, 28 * 28)) / 255.0

第二部分:构建网络结构

code_dim = 64
inputs = Input(shape=(x_train.shape[1],), name='inputs')  # 就是flatten
code = Dense(256, activation='relu', name='code1')(inputs)
code = Dense(128, activation='relu', name='code2')(code)
code = Dense(code_dim, activation='relu', name='code3')(code)

outputs = Dense(64, activation='relu', name='outputs1')(code)
outputs = Dense(128, activation='relu', name='outputs2')(outputs)
outputs = Dense(256, activation='relu', name='outputs3')(outputs)
outputs = Dense(x_train.shape[1], activation='softmax', name='decoder')(outputs)

第三模型训练

auto_encoder = Model(inputs, outputs)  # 训练模型 输入输出都是784维。
auto_encoder.compile(optimizer='adam',
                     loss='binary_crossentropy')
auto_encoder.fit(x_train, x_train, batch_size=64, epochs=5, validation_split=0.1)

第四预测结果与结果可视化

auto_encoder_result = auto_encoder.predict(x_test)

n = 5
for i in range(n):
    ax = plt.subplot(3, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax = plt.subplot(3, n, n + i + 1)
    plt.imshow(auto_encoder_result[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()

全部代码过程

from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np

############################# 数据准备 #############################
(x_train, y_train), (x_test, y_test) = mnist.load_data()
np.random.seed(1)
np.random.shuffle(x_train)
np.random.shuffle(y_train)
np.random.shuffle(x_test)
np.random.shuffle(y_test)  # 随机打乱顺序 
x_train = x_train.reshape((-1, 28 * 28)) / 255.0
x_test = x_test.reshape((-1, 28 * 28)) / 255.0

############################# 网络结构 #############################
code_dim = 64
inputs = Input(shape=(x_train.shape[1],), name='inputs')  # 就是flatten
code = Dense(256, activation='relu', name='code1')(inputs)
code = Dense(128, activation='relu', name='code2')(code)
code = Dense(64, activation='relu', name='code3')(code)

outputs = Dense(64, activation='relu', name='outputs1')(code)
outputs = Dense(128, activation='relu', name='outputs2')(outputs)
outputs = Dense(256, activation='relu', name='outputs3')(outputs)
outputs = Dense(x_train.shape[1], activation='softmax', name='decoder')(outputs)

############################# 模型训练 #############################
auto_encoder = Model(inputs, outputs)  # 训练模型 这时候的输入、输出就是一个
auto_encoder.summary()
keras.utils.plot_model(auto_encoder, show_shapes=True)
auto_encoder.compile(optimizer='adam',
                     loss='binary_crossentropy')
auto_encoder.fit(x_train, x_train, batch_size=64, epochs=5, validation_split=0.1)
############################# 模型预测 #############################
auto_encoder_result = auto_encoder.predict(x_test)
############################# 可视化 #############################
n = 5
for i in range(n):
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax = plt.subplot(2, n, n + i + 1)
    plt.imshow(auto_encoder_result[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()






上述代码最常见的作用就是清洗噪声,但是一般情况下autoencoder还需要用来降维,比如我们希望得到图片中64维那个地方的数据 再把它变成一个8*8的。即实现了把原始图片28*28变成了8*8。 

第二,利用autoencoder降维 

利用autoencoder方法对数据进行降维,代码如下:

from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow import keras

############################# 数据准备 #############################
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((-1, 28 * 28)) / 255.0
x_test = x_test.reshape((-1, 28 * 28)) / 255.0

############################# 数据准备 #############################
code_dim = 64
inputs = Input(shape=(x_train.shape[1],), name='inputs')  # 就是flatten
code = Dense(256, activation='relu', name='code1')(inputs)
code = Dense(128, activation='relu', name='code2')(code)
code = Dense(64, activation='relu', name='code3')(code)

outputs = Dense(64, activation='relu', name='outputs1')(code)
outputs = Dense(128, activation='relu', name='outputs2')(outputs)
outputs = Dense(256, activation='relu', name='outputs3')(outputs)
outputs = Dense(x_train.shape[1], activation='softmax', name='decoder')(outputs)

auto_encoder = Model(inputs, outputs)  # 训练模型 这时候的输入、输出就是一个
encoder_model = Model(inputs, code)  # 建立encoder模型,这个模型预测得到就是降维后的数据
auto_encoder.compile(optimizer='adam',
                     loss='binary_crossentropy',
                     )
auto_encoder.fit(x_train, x_train, batch_size=64, epochs=5, validation_split=0.1)
encoded = encoder_model.predict(x_test)  # 784 变成了 64 降维后的结果

n = 5
for i in range(n):
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax = plt.subplot(2, n,  n + i + 1)
    plt.imshow(encoded[i].reshape(8, 8))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()

总结

上述代码全是基于全连接完成的,至于 CNN Autoencoder & LSTM Autoencoder会根据需要进行更新。

 

 

 

标签:code,Autoencoder,name,outputs,代码,28,tensorflow2.0,train,Dense
来源: https://blog.csdn.net/chwei20002005/article/details/115112208

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有