# This is a simple modification of https://github.com/pytorch/pytorch/blob/master/torch/legacy/nn/SpatialCrossMapLRN.py.
import torch
from torch.legacy.nn.Module import Module
from torch.legacy.nn.utils import clear
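
# Cross-channel local response normalization (LRN). For every spatial
# position, each channel is rescaled by a term computed from the squared
# activations of its `size` neighbouring channels:
#
#     output = input * (k + alpha / size * sum_of_squares) ** (-beta)
#
# Compared with the upstream legacy module, this variant adds a `gpuDevice`
# argument and moves its internal buffers to that CUDA device whenever the
# input is a CUDA tensor.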
class SpatialCrossMapLRN_temp(Module):

    def __init__(self, size, alpha=1e-4, beta=0.75, k=1, gpuDevice=0):
        super(SpatialCrossMapLRN_temp, self).__init__()

        self.size = size      # number of neighbouring channels summed over
        self.alpha = alpha    # scaling applied to the summed squares
        self.beta = beta      # exponent of the normalization term
        self.k = k            # additive constant inside the normalization term

        self.scale = None
        self.paddedRatio = None
        self.accumRatio = None
        self.gpuDevice = gpuDevice  # CUDA device index for the internal buffers

    def updateOutput(self, input):
        assert input.dim() == 4

        if self.scale is None:
            self.scale = input.new()

        if self.output is None:
            self.output = input.new()

        batchSize = input.size(0)
        channels = input.size(1)
        inputHeight = input.size(2)
        inputWidth = input.size(3)

        # Unlike the upstream module, CUDA inputs take the same Python path;
        # the buffers are simply moved to the requested device.
        if input.is_cuda:
            self.output = self.output.cuda(self.gpuDevice)
            self.scale = self.scale.cuda(self.gpuDevice)

        self.output.resize_as_(input)
        self.scale.resize_as_(input)

        # use output storage as temporary buffer
        inputSquare = self.output
        torch.pow(input, 2, out=inputSquare)

        prePad = int((self.size - 1) / 2 + 1)
        prePadCrop = channels if prePad > channels else prePad

        scaleFirst = self.scale.select(1, 0)
        scaleFirst.zero_()

        # compute first feature map normalization
        for c in range(prePadCrop):
            scaleFirst.add_(inputSquare.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous
        for c in range(1, channels):
            scalePrevious = self.scale.select(1, c - 1)
            scaleCurrent = self.scale.select(1, c)
            scaleCurrent.copy_(scalePrevious)

            if c < channels - prePad + 1:
                squareNext = inputSquare.select(1, c + prePad - 1)
                scaleCurrent.add_(1, squareNext)

            if c > prePad:
                squarePrevious = inputSquare.select(1, c - prePad)
                scaleCurrent.add_(-1, squarePrevious)

        # scale = k + alpha / size * sum of squares, then output = input * scale^-beta
        self.scale.mul_(self.alpha / self.size).add_(self.k)

        torch.pow(self.scale, -self.beta, out=self.output)
        self.output.mul_(input)

        return self.output

    def updateGradInput(self, input, gradOutput):
        assert input.dim() == 4

        batchSize = input.size(0)
        channels = input.size(1)
        inputHeight = input.size(2)
        inputWidth = input.size(3)

        if self.paddedRatio is None:
            self.paddedRatio = input.new()
        if self.accumRatio is None:
            self.accumRatio = input.new()

        self.paddedRatio.resize_(channels + self.size - 1, inputHeight, inputWidth)
        self.accumRatio.resize_(inputHeight, inputWidth)

        cacheRatioValue = 2 * self.alpha * self.beta / self.size
        inversePrePad = int(self.size - (self.size - 1) / 2)

        # gradInput = scale^-beta * gradOutput, before the cross-channel correction term
        self.gradInput.resize_as_(input)
        torch.pow(self.scale, -self.beta, out=self.gradInput).mul_(gradOutput)

        self.paddedRatio.zero_()
        paddedRatioCenter = self.paddedRatio.narrow(0, inversePrePad, channels)
        for n in range(batchSize):
            torch.mul(gradOutput[n], self.output[n], out=paddedRatioCenter)
            paddedRatioCenter.div_(self.scale[n])
            torch.sum(self.paddedRatio.narrow(0, 0, self.size - 1), 0, out=self.accumRatio)
            # slide the accumulated ratio window across channels, subtracting the
            # correction term for each channel
            for c in range(channels):
                self.accumRatio.add_(self.paddedRatio[c + self.size - 1])
                self.gradInput[n][c].addcmul_(-cacheRatioValue, input[n][c], self.accumRatio)
                self.accumRatio.add_(-1, self.paddedRatio[c])

        return self.gradInput

    def clearState(self):
        clear(self, 'scale', 'paddedRatio', 'accumRatio')
        return super(SpatialCrossMapLRN_temp, self).clearState()
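

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the legacy
# torch.legacy.nn stack is installed; the tensor shape and LRN hyperparameters
# below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    lrn = SpatialCrossMapLRN_temp(size=5, alpha=1e-4, beta=0.75, k=1)
    x = torch.randn(2, 16, 32, 32)              # (batch, channels, height, width)
    y = lrn.updateOutput(x)                     # forward: normalize across channels
    grad_out = torch.ones(y.size())             # dummy upstream gradient
    grad_in = lrn.updateGradInput(x, grad_out)  # backward: gradient w.r.t. the input
    print(y.size(), grad_in.size())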