Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,59 @@ prod{T}(A::AbstractArray{T}, region) = _prod!(reduction_init(A, region, one(T)*o

prod(A::AbstractArray{Bool}, region) = error("use all() instead of prod() for boolean arrays")


### findmin/findmax
# Generate the body for a reduction function reduce!(f, Rval, Rind, A), using a comparison operator f
# Rind contains the index of A from which Rval was taken
function gen_findreduction_body(N, f::Function)
F = Expr(:quote, f)
quote
(isempty(Rval) || isempty(A)) && return Rval, Rind
for i = 1:$N
(size(Rval, i) == size(A, i) || size(Rval, i) == 1) || throw(DimensionMismatch("Find-reduction on array of size $(size(A)) with output of size $(size(Rval))"))
size(Rval, i) == size(Rind, i) || throw(DimensionMismatch("Find-reduction: outputs must be of the same size"))
end
@nexprs $N d->(sizeR_d = size(Rval,d))
# If we're reducing along dimension 1, for efficiency we can make use of a temporary.
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
k = 0
@inbounds if size(Rval, 1) < size(A, 1)
@nloops $N i d->(d>1? (1:size(A,d)) : (1:1)) d->(j_d = sizeR_d==1 ? 1 : i_d) begin
tmpRv = (@nref $N Rval j)
tmpRi = (@nref $N Rind j)
for i_1 = 1:size(A,1)
k += 1
tmpAv = (@nref $N A i)
if ($F)(tmpAv, tmpRv)
tmpRv = tmpAv
tmpRi = k
end
end
(@nref $N Rval j) = tmpRv
(@nref $N Rind j) = tmpRi
end
else
@nloops $N i A d->(j_d = sizeR_d==1 ? 1 : i_d) begin
k += 1
tmpAv = (@nref $N A i)
if ($F)(tmpAv, (@nref $N Rval j))
(@nref $N Rval j) = tmpAv
(@nref $N Rind j) = k
end
end
end
Rval, Rind
end
end

eval(ngenerate(:N, :(typeof((Rval,Rind))), :(_findmin!{T,N}(Rval::AbstractArray, Rind::AbstractArray, A::AbstractArray{T,N})), N->gen_findreduction_body(N, <)))
findmin!{R}(rval::AbstractArray{R}, rind::AbstractArray, A::AbstractArray; init::Bool=true) = _findmin!(initarray!(rval, typemax(R), init), rind, A)
findmin{T}(A::AbstractArray{T}, region) =
isempty(A) ? (similar(A,reduced_dims0(A,region)), zeros(Int,reduced_dims0(A,region))) :
_findmin!(reduction_init(A, region, typemax(T)), zeros(Int,reduced_dims0(A,region)), A)

eval(ngenerate(:N, :(typeof((Rval,Rind))), :(_findmax!{T,N}(Rval::AbstractArray, Rind::AbstractArray, A::AbstractArray{T,N})), N->gen_findreduction_body(N, >)))
findmax!{R}(rval::AbstractArray{R}, rind::AbstractArray, A::AbstractArray; init::Bool=true) = _findmax!(initarray!(rval, typemin(R), init), rind, A)
findmax{T}(A::AbstractArray{T}, region) =
isempty(A) ? (similar(A,reduced_dims0(A,region)), zeros(Int,reduced_dims0(A,region))) :
_findmax!(reduction_init(A, region, typemin(T)), zeros(Int,reduced_dims0(A,region)), A)
10 changes: 10 additions & 0 deletions doc/stdlib/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -608,10 +608,20 @@ Iterable Collections

Returns the maximum element and its index.

.. function:: findmax(A, dims) -> (maxval, index)

For an array input, returns the value and index of the maximum over
the given dimensions.

.. function:: findmin(itr) -> (x, index)

Returns the minimum element and its index.

.. function:: findmin(A, dims) -> (minval, index)

For an array input, returns the value and index of the minimum over
the given dimensions.

.. function:: sum(itr)

Returns the sum of all elements in a collection.
Expand Down
10 changes: 10 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,13 @@ R = reducedim((a,b) -> a+b, [1 2; 3 4], 2, 0.0)
# inferred return types
rt = Base.return_types(reducedim, (Function, Array{Float64, 3}, Int, Float64))
@test length(rt) == 1 && rt[1] == Array{Float64, 3}

## findmin/findmax
A = [1.0 3.0 6.0;
5.0 2.0 4.0]
@test findmin(A, (1,)) == ([1.0 2.0 4.0], [1 4 6])
@test findmin(A, (2,)) == (reshape([1.0,2.0], 2, 1), reshape([1,4], 2, 1))
@test findmin(A, (1,2)) == (fill(1.0,1,1),fill(1,1,1))
@test findmax(A, (1,)) == ([5.0 3.0 6.0], [2 3 5])
@test findmax(A, (2,)) == (reshape([6.0,5.0], 2, 1), reshape([5,2], 2, 1))
@test findmax(A, (1,2)) == (fill(6.0,1,1),fill(5,1,1))