87 changes: 87 additions & 0 deletions Normalize.lua
@@ -0,0 +1,87 @@
local Normalize, parent = torch.class('nn.Normalize', 'nn.Module')

function Normalize:__init(p, eps)
   parent.__init(self)
   assert(p, 'p-norm not provided')
   assert(p > 0, p..'-norm not supported')
   self.p = p
   self.eps = eps or 1e-10 -- smoothing term, avoids division by zero for all-zero inputs
end

function Normalize:updateOutput(input)
   assert(input:dim() <= 2, 'only 1D or 2D (batch mode) tensors are supported')
   local is_batch = true
   if input:dim() == 1 then
      input = input:view(1, -1)
      is_batch = false
   end

   self.output:resizeAs(input)

   self.norm = self.norm or input.new()
   self.normp = self.normp or input.new()
   self.buffer = self.buffer or input.new()

   if self.p % 2 ~= 0 then
      self.buffer:abs(input):pow(self.p) -- |x|^p needs an explicit abs when p is odd or fractional
   else
      self.buffer:pow(input, self.p)
   end
   self.normp:sum(self.buffer, 2):add(self.eps) -- ||x||_p^p per row
   self.norm:pow(self.normp, 1/self.p)          -- ||x||_p per row
   self.output:cdiv(input, self.norm:view(-1,1):expandAs(self.output))

   if not is_batch then
      self.output = self.output[1]
   end
   return self.output
end

function Normalize:updateGradInput(input, gradOutput)
   assert(input:dim() <= 2, 'only 1D or 2D (batch mode) tensors are supported')
   assert(gradOutput:dim() <= 2, 'only 1D or 2D (batch mode) tensors are supported')

   local is_batch = true
   if input:dim() == 1 then
      input = input:view(1, -1)
      is_batch = false
   end

   local n = input:size(1) -- batch size
   local d = input:size(2) -- dimensionality of vectors
   -- compute diagonal term
   self.eye = self.eye or torch.eye(d):typeAs(input):view(1, d, d)
   local eyeExpand = self.eye:expand(n, d, d)
   self.diag = self.diag or self.eye.new()
   self.diag:cmul(eyeExpand, self.normp:view(n, 1, 1):expand(n, d, d))
   -- compute cross term
   self.buffer:abs(input):pow(self.p-2):cmul(input) -- sign(x) * |x|^(p-1)
   local b1 = self.buffer:view(n, d, 1)
   local b2 = input:view(n, 1, d)

   self.diag:baddbmm(-1, b1, b2)
   -- compute the local gradient of the Lp transformation
   self.buffer:cmul(self.normp, self.norm)
   self.diag:cdiv(self.buffer:view(n, 1, 1):expand(n, d, d))
   -- chain the gradient
   self.gradInput:resize(n, d, 1)
   self.gradInput:bmm(self.diag, gradOutput:view(n, d, 1))
   self.gradInput = self.gradInput:view(n, d)

   if not is_batch then
      self.gradInput = self.gradInput[1]
   end

   return self.gradInput
end

function Normalize:__tostring__()
   local s
   -- different prints if the norm is integer
   if self.p % 1 == 0 then
      s = '%s(%d)'
   else
      s = '%s(%f)'
   end
   return string.format(s, torch.type(self), self.p)
end
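For context on `updateGradInput` above (this note and sketch are not part of the PR): for `y = x / ||x||_p`, the Jacobian is `dy_i/dx_j = delta_ij/||x||_p - x_i * |x_j|^(p-2) * x_j / ||x||_p^(p+1)`. The matrix assembled in `self.diag` is its transpose, `(||x||_p^p * I - (|x|^(p-2) .* x) * x^T) / ||x||_p^(p+1)`, so the batched `bmm` with `gradOutput` yields `gradInput = J^T * gradOutput`. A minimal finite-difference sanity check of that gradient, assuming a standard `torch`/`nn` installation, could look like this:

```lua
require 'nn'

-- hypothetical sketch, not part of the PR: compare the analytic gradient of
-- nn.Normalize against a central finite difference on one input coordinate
local x = torch.randn(4, 6)
local m = nn.Normalize(1.5)
m:forward(x)
local gradOutput = torch.randn(4, 6)
local analytic = m:backward(x, gradOutput):clone()

local i, j, h = 2, 3, 1e-5
local xp = x:clone(); xp[i][j] = xp[i][j] + h
local xm = x:clone(); xm[i][j] = xm[i][j] - h
-- numerical derivative of <y, gradOutput> with respect to x[i][j]
local numeric = (m:forward(xp):clone() - m:forward(xm)):cmul(gradOutput):sum() / (2*h)
print(analytic[i][j], numeric) -- the two values should agree to roughly 1e-5
```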
18 changes: 18 additions & 0 deletions doc/simple.md
@@ -28,6 +28,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [Power](#nn.Power) : an element-wise [pow](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchpowres-x) operation ;
* [Square](#nn.Square) : an element-wise square operation ;
* [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsqrtres-x) operation ;
* [Normalize](#nn.Normalize) : normalizes the input to have unit `L_p` norm ;
* [MM](#nn.MM) : matrix-matrix multiplication (also supports batches of matrices) ;
* Miscellaneous Modules :
* [BatchNormalization](#nn.BatchNormalization) - mean/std normalization over the mini-batch inputs (with an optional affine transform) ;
@@ -886,6 +887,23 @@ gnuplot.grid(true)

![](image/power.png)

<a name="nn.Normalize"></a>
## Normalize ##

```lua
module = nn.Normalize(p, [eps])
```
Normalizes the input Tensor so that it has unit `L_p` norm. The smoothing parameter `eps` prevents division by zero when the input is all zeros (default = `1e-10`).

The input can be 1D or 2D; a 2D input is treated as a batch of vectors, each row being normalized independently.

```lua
A = torch.randn(3, 5)
m = nn.Normalize(2)
B = m:forward(A) -- B is also 3 x 5
-- take the L2 norm over the second axis:
print(torch.norm(B, 2, 2)) -- prints a 3x1 tensor of ones
```
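
The non-batch case works the same way on a single vector; a minimal sketch of it (this example is not in the original docs):

```lua
-- normalize a single vector to unit L1 norm
x = torch.randn(5)
m = nn.Normalize(1)
y = m:forward(x)        -- y has the same size as x
print(torch.norm(y, 1)) -- prints 1 (up to the eps smoothing term)
```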

<a name="nn.MM"></a>
## MM ##
1 change: 1 addition & 0 deletions init.lua
@@ -46,6 +46,7 @@ include('WeightedEuclidean.lua')
include('PairwiseDistance.lua')
include('CosineDistance.lua')
include('DotProduct.lua')
include('Normalize.lua')

include('Exp.lua')
include('Log.lua')
43 changes: 43 additions & 0 deletions test.lua
@@ -412,6 +412,49 @@ function nntest.Power()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end

function nntest.Normalize()
   -- compare forward against torch implementation
   -- and check gradient
   for _,p in pairs({1,2,1.5}) do
      local ini = math.random(3,10)
      local input = torch.randn(ini)
      local module = nn.Normalize(p)
      local out = module:forward(input)
      local expected = torch.div(input,input:norm(p))
      mytester:assertTensorEq(out, expected, 1e-7,
                              torch.typename(module) ..' (' .. p ..') - forward err ')

      local err = jac.testJacobian(module, input, -2, 2)
      mytester:assertlt(err, precision, 'error norm '..p..' on state ')
   end

   -- batch mode
   for _,p in pairs({1,2,torch.uniform()*math.random(1,10)}) do
      local ini = math.random(3,5)
      local inj = math.random(3,5)
      local ink = math.random(3,5)
      local input = torch.Tensor(inj, ini):zero()

      local module = nn.Normalize(p)

      local err = jac.testJacobian(module, input, -2, 2)
      mytester:assertlt(err, precision, 'error norm '..p..' on state ')
   end

   -- test IO correctness
   local ini = math.random(3,5)
   local inj = math.random(3,5)
   local ink = math.random(3,5)
   local input = torch.Tensor(inj, ini):zero()

   local module = nn.Normalize(2)

   local ferr, berr = jac.testIO(module, input, 0.1, 2)
   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')

end

function nntest.Square()
local in1 = torch.rand(5,7)
local module = nn.Square()