Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ New library functions
* `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]).
* `eachrsplit(string, pattern)` iterates split substrings right to left.
* `Sys.username()` can be used to return the current user's username ([#51897]).
* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap` ([#52049]).

New library features
--------------------
Expand Down
33 changes: 33 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3039,3 +3039,36 @@ intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)
_getindex(v, i)
end
end

"""
wrap(Array, m::Union{Memory{T}, MemoryRef{T}}, dims)

Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version
of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers.
"""
function wrap end

@eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N}
mem = ref.mem
mem_len = length(mem) + 1 - memoryrefoffset(ref)
len = Core.checked_dims(dims...)
@boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len)
if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len)
mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len)
ref = MemoryRef(mem)
end
$(Expr(:new, :(Array{T, N}), :ref, :dims))
end

@noinline invalid_wrap_err(len, dims, proddims) = throw(DimensionMismatch(
"Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $proddims > $len, so that the array would have more elements than the underlying memory can store."))

function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N}
wrap(Array, MemoryRef(m), dims)
end
function wrap(::Type{Array}, m::MemoryRef{T}, l::Integer) where {T}
wrap(Array, m, (l,))
end
function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T}
wrap(Array, MemoryRef(m), (l,))
end
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ export
vcat,
vec,
view,
wrap,
zeros,

# search, find, match and related functions
Expand Down
19 changes: 19 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3170,3 +3170,22 @@ end
@test c + zero(c) == c
end
end

@testset "Wrapping Memory into Arrays" begin
mem = Memory{Int}(undef, 10) .= 1
memref = MemoryRef(mem)
@test_throws DimensionMismatch wrap(Array, mem, (10, 10))
@test wrap(Array, mem, (5,)) == ones(Int, 5)
@test wrap(Array, mem, 2) == ones(Int, 2)
@test wrap(Array, memref, 10) == ones(Int, 10)
@test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2)
@test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2)

memref2 = MemoryRef(mem, 3)
@test wrap(Array, memref2, (5,)) == ones(Int, 5)
@test wrap(Array, memref2, 2) == ones(Int, 2)
@test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2)
@test wrap(Array, memref2, (3, 2)) == ones(Int, 3, 2)
@test_throws DimensionMismatch wrap(Array, memref2, 9)
@test_throws DimensionMismatch wrap(Array, memref2, 10)
end