ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

GFPGAN源码分析—第六篇

2021-12-06 23:05:38  阅读:412  来源: 互联网

标签:size narrow nn self channels 源码 GFPGAN 第六篇 out


2021SC@SDUSC

源码:archs\gfpganv1_clean_arch.py

本篇主要分析gfpganv1_clean_arch.py下的

class GFPGANv1Clean(nn.Module)类_init_()方法

目录

class GFPGANv1Clean(nn.Module)

init()

(1)channels的设置

(2)调用torch.nn.Conv2d()创建了一层卷积神经网络

(3)下采样(downsample)

(4)上采样(upsample)

(5)全连接层

(6)创建self.stylegan_decoder

(7)如果decoder_load_path不为空则读取

(8)for SFT(SFT layer)


class GFPGANv1Clean(nn.Module)

        继承自nn.Module类,使得我们可以使用很多现成的类,比如本类中使用的Conv2d以及RelU激活函数等等。

init()

参数:

self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False

在class GFPGANer()-init()中被调用时:

self.gfpgan = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=channel_multiplier,
    decoder_load_path=None,
    fix_decoder=False,
    num_mlp=8,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True)

(1)channels的设置

实际调用的时候narrow=1,

channels保存了经过convolution层后的输出的通道数

unet_narrow = narrow * 0.5

channels = {
    '4': int(512 * unet_narrow),
    '8': int(512 * unet_narrow),
    '16': int(512 * unet_narrow),
    '32': int(512 * unet_narrow),
    '64': int(256 * channel_multiplier * unet_narrow),
    '128': int(128 * channel_multiplier * unet_narrow),
    '256': int(64 * channel_multiplier * unet_narrow),
    '512': int(32 * channel_multiplier * unet_narrow),
    '1024': int(16 * channel_multiplier * unet_narrow)
}

(2)调用torch.nn.Conv2d()搭建卷积神经网络

#out_size=512,so log_size=9
self.log_size = int(math.log(out_size, 2))
#first_out_size = 512
first_out_size = 2 ** (int(math.log(out_size, 2)))
#channels['512']=32*2*0.5=32
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

在这里介绍一下nn.Conv2d()的几个参数

in_channels: int,#输入的通道数目【必选】
out_channels: int,# 输出的通道数目【必选】
kernel_size: _size_2_t,#卷积核的大小,类型为int(方形边长) 或者元组(长和宽)【必选】
stride: _size_2_t = 1,#步长
padding: Union[str, _size_2_t] = 0,#边界增益,可以控制输出结果的尺寸
dilation: _size_2_t = 1,#控制卷积核之间的间距
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',  # TODO: refine this type
device=None,
dtype=None

那么可以得知

self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)


#实际上是传入通道为3(RGB)的输入,使用边长为1的卷积核,最后获得通道为32的输出
#由于卷积核边长为1,我们输入与输入的图片大小仍然保持一致,但增加了通道数

(3)下采样(downsample)

可以看到实际上是调用ResBlock做了下采样

# 输入图片的通道数(实际为32)
in_channels = channels[f'{first_out_size}']
 #创建ModuleList容器
self.conv_body_down = nn.ModuleList()
# i从self.log_size(9)->3      :7次循环
for i in range(self.log_size, 2, -1):
    out_channels = channels[f'{2 ** (i - 1)}']
    #调用ResBlock残差网络做下采样,并将该module添加到设置的ModuleList
    self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
    #这一层的输出管道数作为下一层输入的管道数
    in_channels = out_channels

介绍一下nn.ModuleList()

nn.ModuleList,它是一个储存不同module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。
#注意nn.ModuleList则没有实现内部forward函数,所以需要手动实现

最后一层卷积层的搭建:

#最终输出通道数为channels['4']=256,使用边长为3的卷积核,步长为1,padding为1,保证维度不变
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

(4)上采样(upsample)

#输入通道数为channels['4']=256,即下采样的输出的通道数
        in_channels = channels['4']
        #创建ModuleList容器
        self.conv_body_up = nn.ModuleList()
        # i从3->self.log_size(9)     :7次循环
        for i in range(3, self.log_size + 1):
            # 定义输出的通道数
            out_channels = channels[f'{2 ** i}']
            # 调用带有上采样ResBlock残差网络,并将该module添加到设置的ModuleList
            self.conv_body_up.append(ResBlock(in_channels, out_channels, 
                                              mode='up'))
            #这一层的输出管道数作为下一层输入的管道数
            in_channels = out_channels

(5)全连接层

根据传入的参数different_w,选择每个输出样本的大小,并搭建相应的全连接层。

if different_w:
    #16*512=8192
    linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
    print(linear_out_channel)
else:
    #512
    linear_out_channel = num_style_feat
#全连接层size of each input sample:4096,size of each output sample:8192
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

(6)创建self.stylegan_decoder

self.stylegan_decoder = StyleGAN2GeneratorCSFT(
    out_size=out_size,
    num_style_feat=num_style_feat,
    num_mlp=num_mlp,
    channel_multiplier=channel_multiplier,
    narrow=narrow,
    sft_half=sft_half)

(7)如果decoder_load_path不为空则读取

if decoder_load_path:
    self.stylegan_decoder.load_state_dict(
        torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
if fix_decoder:
    for name, param in self.stylegan_decoder.named_parameters():
        param.requires_grad = False

(8)for SFT(SFT layer)

#ModuleList
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
  # i从3->self.log_size(9)     :7次循环
for i in range(3, self.log_size + 1):
    # 定义输出的通道数
    out_channels = channels[f'{2 ** i}']
     #输出通道数是否减半
    if sft_half:
        sft_out_channels = out_channels
    else:
        sft_out_channels = out_channels * 2
         #使用nn.Sequential搭建网络,并添加到ModuleList
    self.condition_scale.append(
        nn.Sequential(
             #卷积核边长为3,步长为1,输出与输出保持相同维度
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, 
                                                                         True),
            nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
    self.condition_shift.append(
        nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, 
                                                                         True),
            nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

nn.Sequential是一个有序的容器,其中传入的是构造器类(各种用来处理input的类),最终input会被Sequential中的构造器依次执行。

标签:size,narrow,nn,self,channels,源码,GFPGAN,第六篇,out
来源: https://blog.csdn.net/Vaifer233/article/details/121758019

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有