# From 0c4208d94ac7e7261a4936aef2db00d2fcd7d073 Mon Sep 17 00:00:00 2001
# From: marsggbo
# Date: Thu, 15 Mar 2018 12:15:03 +0800
# Subject: [PATCH 1/4] spp_layer module for pytorch
# ---
#  SPP_Layer.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++
#  1 file changed, 46 insertions(+)
#  create mode 100644 SPP_Layer.py
# diff --git a/SPP_Layer.py b/SPP_Layer.py
# new file mode 100644
# index 0000000..10fefa4
# --- /dev/null
# +++ b/SPP_Layer.py
# @@ -0,0 +1,46 @@
#coding=utf-8

import math
import torch
import torch.nn.functional as F


# Building SPP layer (spatial pyramid pooling layer)
class SPPLayer(torch.nn.Module):
    """Spatial pyramid pooling over a 4-D ``(num, c, h, w)`` input.

    For every level ``1..num_levels`` the input is pooled with a window of
    ``ceil(h / level) x ceil(w / level)`` (stride equal to the window) and a
    symmetric padding chosen so the windows cover the whole input. The pooled
    maps of all levels are flattened per sample and concatenated.

    Args:
        num_levels: number of pyramid levels (must be >= 1).
        pool_type: ``'max_pool'`` selects max pooling; any other value
            selects average pooling.

    Raises:
        ValueError: if ``num_levels`` is smaller than 1.
    """

    def __init__(self, num_levels, pool_type='max_pool'):
        super(SPPLayer, self).__init__()

        # Without at least one level there is nothing to concatenate in forward().
        if num_levels < 1:
            raise ValueError("num_levels must be >= 1, got %r" % (num_levels,))

        self.num_levels = num_levels
        self.pool_type = pool_type

    def forward(self, x):
        """Return the concatenated pyramid features, shape ``(num, total_features)``.

        Args:
            x: tensor whose ``size()`` unpacks into ``(num, c, h, w)``:
                number of samples, channels, height, width.

        Note:
            The pooling functions reject a padding larger than half the kernel
            size, which can happen for small inputs combined with deep levels.
        """
        num, _, h, w = x.size()
        # Choose the pooling function once; everything but 'max_pool' means average.
        pool = F.max_pool2d if self.pool_type == 'max_pool' else F.avg_pool2d

        pooled = []
        for level in range(1, self.num_levels + 1):
            # The equation is explained on the following site:
            # http://www.cnblogs.com/marsggbo/p/8572846.html#autoid-0-0-0
            kernel_size = (math.ceil(h / level), math.ceil(w / level))
            padding = (
                math.floor((kernel_size[0] * level - h + 1) / 2),
                math.floor((kernel_size[1] * level - w + 1) / 2),
            )
            # Stride equals the kernel size, so windows do not overlap.
            features = pool(x, kernel_size=kernel_size, stride=kernel_size, padding=padding)
            pooled.append(features.view(num, -1))

        # Flatten-and-concatenate once instead of growing the result per level.
        return torch.cat(pooled, 1)


# From cb6b5a6f2f8286e44d68fbdbf210f9f57979b71a Mon Sep 17 00:00:00 2001
# From: marsggbo
# Date: Fri, 16 Mar 2018 11:16:16 +0800
# Subject: [PATCH 2/4] Update SPP_Layer.py
# ---
#  SPP_Layer.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++
#  1 file changed, 71 insertions(+)
# diff --git
a/SPP_Layer.py b/SPP_Layer.py index 10fefa4..1ea744e 100644 --- a/SPP_Layer.py +++ b/SPP_Layer.py @@ -38,6 +38,77 @@ def forward(self, x): else: tensor = F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1) + # 展开、拼接 + if (i == 0): + x_flatten = tensor.view(num, -1) + else: + x_flatten = torch.cat((x_flatten, tensor.view(num, -1)), 1) + return x_flatten + +# 上面的代码在当数据大小比较小的时候可能会出现下面这种恶心的错误, 即 padding的大小需要小于kernel一半的大小, +# 所以为了解决这个问题,下面代码作了进一步修改,主要方法就是先对数据进行手动更新padding,然后再计算出此时的kernel和stride +# 经测试即使输入数据大小是(7,9), spp_level=4也是正常运行的 + +# The code above may raise an annoying error when the input is relatively small, +# namely that the padding size must be less than half of the kernel size. To solve this problem, +# the code below is modified further: it first pads the input data manually +# and then computes the kernel size and stride from the padded size. +# Tested: even with an input of size (7,9) and spp_level=4, the code runs successfully. 
class Modified_SPPLayer(torch.nn.Module):
    """Spatial pyramid pooling that zero-pads the input before pooling.

    Unlike ``SPPLayer``, the padding is applied explicitly with
    ``torch.nn.ZeroPad2d`` and the kernel size and stride are then derived
    from the padded size, so the pooling call itself uses no padding and its
    padding-vs-kernel-size restriction does not apply.

    Args:
        num_levels: number of pyramid levels (must be >= 1).
        pool_type: ``'max_pool'`` selects max pooling; any other value
            selects average pooling.

    Raises:
        ValueError: if ``num_levels`` is smaller than 1.

    Note:
        Because the padding consists of zeros, max pooling over a border
        window returns at least 0 even if all real activations in that
        window are negative.
    """

    def __init__(self, num_levels, pool_type='max_pool'):
        # Must name this class: it is not a subclass of SPPLayer, so
        # super(SPPLayer, self) raises TypeError on construction.
        super(Modified_SPPLayer, self).__init__()

        # Without at least one level there is nothing to concatenate in forward().
        if num_levels < 1:
            raise ValueError("num_levels must be >= 1, got %r" % (num_levels,))

        self.num_levels = num_levels
        self.pool_type = pool_type

    def forward(self, x):
        """Return the concatenated pyramid features, shape ``(num, total_features)``.

        Args:
            x: tensor whose ``size()`` unpacks into ``(num, c, h, w)``:
                number of samples, channels, height, width.

        Errors raised by the pooling call propagate to the caller instead of
        being printed and ignored.
        """
        num, _, h, w = x.size()
        # Choose the pooling function once; everything but 'max_pool' means average.
        pool = F.max_pool2d if self.pool_type == 'max_pool' else F.avg_pool2d

        pooled = []
        for level in range(1, self.num_levels + 1):
            # The equation is explained on the following site:
            # http://www.cnblogs.com/marsggbo/p/8572846.html#autoid-0-0-0
            base_kernel = (math.ceil(h / level), math.ceil(w / level))
            pad_h = math.floor((base_kernel[0] * level - h + 1) / 2)
            pad_w = math.floor((base_kernel[1] * level - w + 1) / 2)

            # Update the input data with padding; ZeroPad2d takes
            # (left, right, top, bottom).
            x_new = torch.nn.ZeroPad2d((pad_w, pad_w, pad_h, pad_h))(x)

            # Update kernel and stride from the padded size.
            h_new = 2 * pad_h + h
            w_new = 2 * pad_w + w
            kernel_size = (math.ceil(h_new / level), math.ceil(w_new / level))
            stride = (math.floor(h_new / level), math.floor(w_new / level))

            features = pool(x_new, kernel_size=kernel_size, stride=stride)
            pooled.append(features.view(num, -1))

        # Flatten-and-concatenate once instead of growing the result per level.
        return torch.cat(pooled, 1)


# From 17e66d8b73a09ddc5a154c716e1989a5292cc755 Mon Sep 17 00:00:00 2001
# From: marsggbo
# Date: Fri, 26 Jul 2019 22:52:56 +0800
# Subject: [PATCH 3/4] Update README.md
# ---
#  README.md | 1 +
#  1 file changed, 1 insertion(+)
# diff --git a/README.md b/README.md
# index 1d0824e..371fcb0 100644
# --- a/README.md
# +++ b/README.md
# @@ -6,3 +6,4 @@ The function `spatial_pyramid_pool()` in file `spp_layer.py` is independent. It

See this:Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition +`SPP_Layer.py` provides the SPP layer as a `torch.nn.Module`, which can be inserted into any model very easily. From 98e4ffbcd274ec7d8d2afc04718c620f7ca7a267 Mon Sep 17 00:00:00 2001 From: marsggbo Date: Fri, 26 Jul 2019 22:53:17 +0800 Subject: [PATCH 4/4] Update README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 371fcb0..58d4d8f 100644 --- a/README.md +++ b/README.md @@ -6,4 +6,7 @@ The function `spatial_pyramid_pool()` in file `spp_layer.py` is independent. It

See this:Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition + + + `SPP_Layer.py` provides the SPP layer as a `torch.nn.Module`, which can be inserted into any model very easily.