ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

生成对抗网络(GAN)系列(一)

2021-09-17 22:02:38  阅读:476  来源: 互联网

标签:dim self boldsymbol 生成 GAN data 对抗 size


生成式模型的作用

密度估计

给定一组数据\(D=\left \{ x^{n} \right \}^{N}_{n=1}\),假设它们都是独立地从相同的概率密度函数为\(p_{r}(x)\)的未知分布中产生的。密度估计是根据数据集\(D\)来估计其概率密度函数\(p_{\theta}(x)\)。
在机器学习中,密度估计是一类无监督学习问题。比如在手写体数字图像的密度估计问题中,我们将图像表示为一个随机变量\(X\),其中每一维都表示一个像素值。假设手写体数字图像都服从一个未知的分布\(p_{r}{x}\),希望通过一些观测样本来估计其分布。但是手写体数字图像中不同像素之间存在复杂的依赖关系,很难用一个明确的图模型来描述其依赖关系,所以直接建模\(p_{r}{x}\),比较困难。因此,我们通过引入隐变量\(z\)来简化模型,这样密度估计问题可以转化为估计变量(x, z)的两个局部条件概率\(p_{\theta}(z)\)和\(p_{\theta}(x|z)\)。一般为了简化模型,假设隐变量\(z\)的先验分布为标准高斯分布\(N(0, I)\)。隐变量\(z\)的每一维度之间都是独立的,密度估计的重点是估计条件分布\(p(x|z; \theta)\)。
如果要建模隐含变量的分布,就需要用EM算法来进行密度估计,而在EM算法中,需要估计条件分布\(p(x|z; \theta)\)以及后验概率分布\(p(z|x; \theta)\)。当这两个分布比较复杂时,就可以利用神经网络来建模(如变分自编码器)。

生成样本

在知道\(p_{\theta}(z)\)和得到\(p_{\theta}(x|z)\)之后就可以生成新的数据:

  • 从隐变量的先验分布\(p_{\theta}(z)\)中采样,得到样本\(z\)。
  • 根据条件概率分布\(p_{\theta}(x|z)\)进行采样,得到新的样本\(x\)

生成对抗网络

本文的重点是生成对抗网络(GAN)。与一般的生成式模型(如VAE、DQN)不同,GAN并不直接建模\(p(x)\),而是直接通过一个神经网络学习从隐变量\(z\)到数据\(x\)的映射,称为生成器;然后将生成的样本交给判别网络判断是否是真实的样本。可以看出,生成网络和判别网络的训练是彼此依存、交替进行的。

生成对抗网络流程图

判别网络

判别网络\(D(\boldsymbol x;\phi )\)的目标是区分出一个样本\(\boldsymbol x\)是来自于真实分布\(p_{r}(\boldsymbol x)\)还是来自于生成模型\(p_{\theta}(\boldsymbol x)\)。由此可见,判别网络实际上是一个二分类的分类器。用标签\(y=1\)来表示样本来自真实分布,\(y=0\)表示样本来自生成模型,判别网络\(D(\boldsymbol x;\phi )\)的输出为\(\boldsymbol x\)属于真实数据分布的概率,即:

\[p(y=1 | x) = D(\boldsymbol x;\phi ) \]

样本来自生成模型的概率为:

\[p(y=0 | x) = 1 - D(\boldsymbol x;\phi ) \]

给定一个样本\((x,y),y= \left \{ 1,0 \right \}\),表示其来自于\(p_{r}(\boldsymbol x)\)还是\(p_{\theta}(\boldsymbol x)\),判别网络的目标函数为最小化交叉熵,即:

\[\mathop{min}_{\phi }-\left ( \mathbb{E}_{x}\left [ ylogp(y=1| \boldsymbol x) + (1-y)log p(y=0| \boldsymbol x)\right ] \right ) \]

假设分布\(p(\boldsymbol x)\)是由分布\(p_{r}(\boldsymbol x)\)和分布\(p_{\theta}(\boldsymbol x)\)等比例混合而成,即\(p(\boldsymbol x) = \frac{1}{2} * \left (p_{r}(\boldsymbol x) + p_{\theta}(\boldsymbol x) \right )\),则上式等价于:

\[\mathop{max}_{\phi } \mathbb{E}_{\boldsymbol x \sim p_{r}(\boldsymbol x)}\left [ logD(\boldsymbol x ;\phi ) \right ] + \mathbb{E}_{\boldsymbol x ^{'} \sim p_{\theta}(\boldsymbol x ^{'})}\left [ log\left ( 1 - D(\boldsymbol x ^{'} ;\phi ) \right ) \right ] \]

\[=\mathop{max}_{\phi } \mathbb{E}_{\boldsymbol x \sim p_{r}(\boldsymbol x)}\left [ logD(\boldsymbol x ;\phi ) \right ] + \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z )}\left [ log\left ( 1 - D(G(\boldsymbol z ;\theta ) ;\phi ) \right ) \right ] \]

其中\(\theta\)和\(\phi\)分别是生成网络和判别网络的参数。

生成网络

生成网络的目标刚好和判别网络相反,即让判别网络将自己生成的样本判别为真是样本。

\[\mathop{max}_{\theta } \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z)}\left [ logD \left (G (\boldsymbol z; \theta ) ;\phi \right ) \right ] \]

\[=\mathop{min}_{\theta } \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z)}\left [ log(1 - D \left (G (\boldsymbol z; \theta ) ;\phi \right )) \right ] \]

两个目标函数是等价的,但一般使用前者,因为其梯度性质更好。

训练

和单目标的优化任务相比,生成对抗网络的两个网络的优化目标刚好相反。因此生成对抗网络的训练比较难,往往不太稳定. 一般情况下,需要平衡两个网络的能力。对于判别网络来说,一开始的判别能力不能太强,否则难以提升生成网络的能力。但是,判别网络的判别能力也不能太弱,否则针对它训练的生成网络也不会太好。 在训练时需要使用一些技巧,使得在每次迭代中,判别网络比生成网络的能力强一些,但又不能强太多。具体做法是,判别网络更新\(K\)次,生成网络更新1次。

生成对抗网络训练过程

代码实现

hyperparam.py文件
超参数配置模块

import argparse


class HyperParam:
    def __init__(self):
        self.parse = argparse.ArgumentParser()
        self.parse.add_argument("--latent_dim", type=int, default=5)  # 隐含变量的维度
        self.parse.add_argument("--data_dim", type=int, default=10)  # 观测变量的维度
        self.parse.add_argument("--data_size", type=int, default=10000)  # 样本数
        self.parse.add_argument("--g_lr", type=float, default=0.001)
        self.parse.add_argument("--d_lr", type=float, default=0.001)
        self.parse.add_argument("--epochs", type=int, default=300)
        self.parse.add_argument("--K", type=int, default=5)
        self.parse.add_argument("--sample_size", type=int, default=128)
        self.parse.add_argument("--batch_size", type=int, default=128)

gan.py文件
GAN的实现部分

import numpy as np
import torch
from hyperparam import HyperParam
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt

np.random.seed(1000)
torch.manual_seed(1000)


def get_real_data(data_dim, data_size, batch_size):
    base = np.linspace(-1, 1, data_dim)
    a = np.random.uniform(8, 15, data_size).reshape(-1, 1)
    c = np.random.uniform(0.5, 10, data_size).reshape(-1, 1)

    # 构造真实数据
    X = a * np.power(base, 2) + c
    X = torch.from_numpy(X).type(torch.float32)
    data_set = Data.TensorDataset(X)
    data_loader = Data.DataLoader(dataset=data_set,
                                  batch_size=batch_size)

    return base, data_loader


class GAN(nn.Module):
    def __init__(self, latent_dim, data_dim, K, sample_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.K = K
        self.sample_size = sample_size

        self.g = self._generator()
        self.d = self._discriminator()

        self.g_optimizer = torch.optim.Adam(self.g.parameters(), lr=0.001)
        self.d_optimizer = torch.optim.Adam(self.d.parameters(), lr=0.001)

    def _generator(self):
        model = nn.Sequential(
            nn.Linear(self.latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.data_dim)
        )
        return model

    def _discriminator(self):
        model = nn.Sequential(
            nn.Linear(self.data_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        return model

    def d_loss_fn(self, pred_data_result, true_data_result):
        return -torch.mean(torch.log(true_data_result) + torch.log(1 - pred_data_result))

    def g_loss_fn(self, pred_data_result):
        return -torch.mean(torch.log(pred_data_result))

    def train_d(self, true_data):
        sample_size = true_data.shape[0]
        for i in range(self.K):
            # 采样
            sample = torch.rand(sample_size, self.latent_dim)
            # 生成
            fake_data = self.g(sample)
            # 生成数据的判定结果
            fake_data_result = self.d(fake_data)

            # 真实数据的判定结果
            true_data_result = self.d(true_data)

            loss = self.d_loss_fn(fake_data_result, true_data_result)
            self.d_optimizer.zero_grad()
            loss.backward()
            self.d_optimizer.step()

    def train_g(self):
        # 采样
        sample = torch.rand(self.sample_size, self.latent_dim)
        # 生成
        fake_data = self.g(sample)
        # 生成数据的判定结果
        fake_data_result = self.d(fake_data)

        loss = self.g_loss_fn(fake_data_result)
        self.g_optimizer.zero_grad()
        loss.backward()
        self.g_optimizer.step()

    def step(self, true_data):
        self.train_d(true_data)  # 先训练判别器
        self.train_g()  # 再训练生成器


def train(epochs, latent_dim, data_dim, K, sample_size, data_loader, base):
    print('正在训练......')
    model = GAN(latent_dim, data_dim, K, sample_size)

    plt.ion()
    for epoch in range(epochs):
        for true_data in data_loader:
            model.step(true_data[0])  # [128, 15]
        if (epoch + 1) % 50 == 0:
            print('epoch: [{}/{}]'.format(epoch + 1, epochs))
            # 采样
            sample = torch.rand(1, latent_dim)
            # 生成
            fake_data = model.g(sample)
            plt.cla()
            plt.plot(base, fake_data.data.numpy().flatten())
            plt.show()
            plt.pause(0.1)
    plt.ioff()
    plt.show()

    torch.save(model.state_dict(), 'gan_param.pkl')
    print('模型保存成功')


if __name__ == "__main__":
    instance = HyperParam()
    hp = instance.parse.parse_args()
    epochs = hp.epochs
    latent_dim = hp.latent_dim
    data_dim = hp.data_dim
    K = hp.K
    sample_size = hp.sample_size
    data_size = hp.data_size
    batch_size = hp.batch_size

    base, data_loader = get_real_data(data_dim, data_size, batch_size)
    train(epochs, latent_dim, data_dim, K, sample_size, data_loader, base)

运行结果及分析

运行结果

从图中可以看出,从左到右,生成模型绘制二次曲线的能力越来越强了,训练500个epoch之后,生成的图形比较接近真实的二次曲线。

结果分析

实际运行程序时会发现,GAN的生成效果对激活函数和超参数的依赖非常大,特别是超参数K(训练K次判别器之后再训练一次生成器)的取值,如果K的取值稍微不合理,那么会直接导致生成器的损失太大,无法继续优化下去。此外,GAN需要足够的多的样本学习,特别是如果隐变量维度较多的话,需要更多的样本才有可能学得比较好的模型;模型训练过程中存在明显的震荡现象。

GAN的优缺点分析

优点

  • GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链。
  • 相比其他所有模型, GAN可以产生更加清晰,真实的样本。
  • GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域。
  • 相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊。
  • 相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布。换句话说,GANs是渐进一致的,但是VAE是有偏差的。

缺点

  • GAN不适合处理离散形式的数据,比如文本。
  • GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)

关于GAN的一些问题

模式崩溃的原因

一般出现在GAN训练不稳定的时候,具体表现为生成出来的结果非常差,但是即使加长训练时间后也无法得到很好的改善。
具体原因可以解释如下:GAN采用的是对抗训练的方式,G的梯度更新来自D,所以G生成的好不好,得看D怎么说。具体就是G生成一个样本,交给D去评判,D会输出生成的假样本是真样本的概率(0-1),相当于告诉G生成的样本有多大的真实性,G就会根据这个反馈不断改善自己,提高D输出的概率值。但是如果某一次G生成的样本可能并不是很真实,但是D给出了正确的评价,或者是G生成的结果中一些特征得到了D的认可,这时候G就会认为我输出的正确的,那么接下来我就这样输出肯定D还会给出比较高的评价,实际上G生成的并不怎么样,但是他们两个就这样自我欺骗下去了,导致最终生成结果缺失一些信息,特征不全。

为什么优化器不常用SGD

  • SGD容易震荡,容易使GAN训练不稳定。
  • GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的纳什均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题。

为什么不适合处理文本数据

  • 文本数据相比较图片数据来说是离散的,因为对于文本来说,通常需要将一个词映射为一个高维的向量,最终预测的输出是一个one-hot向量,假设softmax的输出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那么变为onehot是(0,1,0,0,0,0),如果softmax输出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以对于生成器来说,G输出了不同的结果但是D给出了同样的判别结果,并不能将
  • GAN的损失函数是JS散度,JS散度不适合衡量不想交分布之间的距离。

训练GAN的技巧

  • 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)
  • 使用wassertein GAN的损失函数
  • 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑
  • 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm
  • 避免使用RELU和pooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数
  • 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率
  • 给D的网络层增加高斯噪声,相当于是一种正则

参考

标签:dim,self,boldsymbol,生成,GAN,data,对抗,size
来源: https://www.cnblogs.com/foghorn/p/15297043.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有