Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ New library functions
* `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]).
* `eachrsplit(string, pattern)` iterates split substrings right to left.
* `Sys.username()` can be used to return the current user's username ([#51897]).
* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap`.

New library features
--------------------
Expand Down
28 changes: 28 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3026,3 +3026,31 @@ intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)
_getindex(v, i)
end
end

"""
wrap(Array, m::Union{Memory{T}, MemoryRef{T}}, dims)

Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version
of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers.
"""
@propagate_inbounds function wrap(::Type{Array}, m::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N}
len = length(m.mem)
@boundscheck len >= prod(dims) || invalid_wrap_err(len, dims)
_wrap(Array, m, convert(Tuple{Vararg{Int}}, dims))
end
@noinline invalid_wrap_err(len, dims) = throw(DimensionMismatch(
"Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $(prod(dims)) > $len, so that the array would have more elements than the underlying memory can store."))

function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N}
wrap(Array, MemoryRef(m), dims)
end
function wrap(::Type{Array}, m::MemoryRef{T}, l::Integer) where {T}
wrap(Array, m, (l,))
end
function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T}
wrap(Array, MemoryRef(m), (l,))
end

@eval @inline function _wrap(::Type{Array}, m::MemoryRef{T}, dims::NTuple{N, Int}) where {T, N}
$(Expr(:new, :(Array{T, N}), :m, :dims))
end
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ export
vcat,
vec,
view,
wrap,
zeros,

# search, find, match and related functions
Expand Down
14 changes: 14 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3135,3 +3135,17 @@ end
@test c + zero(c) == c
end
end

@testset "Wrapping Memory into Arrays" begin
mem = Memory{Int}(undef, 10) .= 1
memref = MemoryRef(mem)
@test_throws DimensionMismatch wrap(Array, mem, (10, 10))
@test wrap(Array, mem, (5,)) == ones(Int, 5)
@test wrap(Array, mem, 2) == ones(Int, 2)
@test wrap(Array, memref, 10) == ones(Int, 10)

# This is broken because length(a::Array{T, N>1}) is currently doing length(a.ref.mem) !!!
@test_broken wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2)
# This works because 5 * 2 happens to equal 10 (the length of mem)
@test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2)
end