ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

动手实现深度学习(13)池化层的实现

2022-09-12 19:05:05  阅读:218  来源: 互联网

标签:13 池化层 实现 self stride shape data pool out


10.1 池化层的运算

传送门: https://www.cnblogs.com/greentomlee/p/12314064.html

github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning

池化层的forward

Pool分为三类 mean-pool, max-pool和min-pool, 本章只讨论max-pool

以下是forwad的运算:

https://blog.csdn.net/nanhuaibeian/article/details/100664570 wps82

池化层的backward的运算

Max-pool的反传是将原来的单元扩大stride_h*stride_w,其余的地方填充0

wps83

 

10.2 池化层的实现

  1 class Pooling:
  2     def __init__(self, pool_h, pool_w, stride=1, pad=0):
  3         self.pool_h = pool_h
  4         self.pool_w = pool_w
  5         self.stride = stride
  6         self.pad = pad
  7 
  8         self.x = None
  9         self.arg_max = None
 10 
 11     def forward(self, x):
 12         N, C, H, W = x.shape
 13         out_h = int(1 + (H - self.pool_h) / self.stride)
 14         out_w = int(1 + (W - self.pool_w) / self.stride)
 15 
 16         col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
 17         col = col.reshape(-1, self.pool_h * self.pool_w)
 18 
 19         arg_max = np.argmax(col, axis=1)
 20         out = np.max(col, axis=1)
 21         out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)
 22 
 23         self.x = x
 24         self.arg_max = arg_max
 25 
 26         return out
 27 
 28     def backward(self, dout):
 29         dout = dout.transpose(0, 2, 3, 1)
 30 
 31         pool_size = self.pool_h * self.pool_w
 32         dmax = np.zeros((dout.size, pool_size))
 33         dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
 34         dmax = dmax.reshape(dout.shape + (pool_size,))
 35 
 36         dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
 37         dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
 38 
 39         return dx

 

10.3 pool单元测试

测试的数据如下:image

 

im2col以后的数据:image

 

Maxpool以后的数据:image

 

测试程序:

  1 # -*- coding: utf-8 -*-
  2 # @File  : test_im2col.py
  3 # @Author: lizhen
  4 # @Date  : 2020/2/14
  5 # @Desc  : 测试im2col
  6 import numpy as np
  7 
  8 from src.common.util import im2col,col2im
  9 from src.common.layers import Convolution,Pooling
 10 
 11 
 12 if __name__ == '__main__':
 13     raw_data = [3, 0, 4, 2,
 14                 6, 5, 4, 3,
 15                 3, 0, 2, 3,
 16                 1, 0, 3, 1,
 17 
 18                 1, 2, 0, 1,
 19                 3, 0, 2, 4,
 20                 1, 0, 3, 2,
 21                 4, 3, 0, 1,
 22 
 23                 4, 2, 0, 1,
 24                 1, 2, 0, 4,
 25                 3, 0, 4, 2,
 26                 6, 2, 4, 5
 27     ]
 28 
 29     raw_filter=[
 30         1,    1,    1,    1,    1,    1,
 31         1,    1,    1,    1,    1,    1,
 32         2,    2,    2,    2,    2,   2,
 33         2,    2,    2,    2,    2,   2,
 34 
 35     ]
 36 
 37 
 38 
 39     input_data = np.array(raw_data)
 40     filter_data = np.array(raw_filter)
 41 
 42     x = input_data.reshape(1,3,4,4)# NCHW
 43     W = filter_data.reshape(2,3,2,2) # NHWC
 44     b = np.zeros(2)
 45     # b = b.reshape((2,1))
 46     # col1 = im2col(input_data=x,filter_h=2,filter_w=2,stride=1,pad=0)#input_data, filter_h, filter_w, stride=1, pad=0
 47     # print(col1)
 48 
 49     # print("input_data.shape=%s"%str(input_data.shape))
 50     # print("W.shape=%s"%str(W.shape))
 51     # print("b.shape=%s"%str(b.shape))
 52     # conv = Convolution(W,b) # def __init__(self, W, b, stride=1, pad=0)
 53     # out = conv.forward(x)
 54     # print("bout.shape=%s"%str(out.shape))
 55     # print(out)
 56 
 57     print("===================")
 58     pool=Pooling( pool_h=2, pool_w=2, stride=2, pad=0)
 59     out = pool.forward(x)
 60     print(out.shape)
 61     print(out)

对应输出:

wps84

标签:13,池化层,实现,self,stride,shape,data,pool,out
来源: https://www.cnblogs.com/greentomlee/p/16686936.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有