diff --git a/Project.toml b/Project.toml index 8edf6e31..a62c6a87 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.25" +version = "0.2.26" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 9ce0bd2a..39225d1a 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -1,4 +1,4 @@ -using Adapt: Adapt, WrappedArray +using Adapt: Adapt, WrappedArray, adapt using ArrayLayouts: zero! using BlockArrays: BlockArrays, @@ -337,60 +337,29 @@ function Base.Array(a::AnyAbstractBlockSparseArray) return Array{eltype(a)}(a) end -using SparseArraysBase: ReplacedUnstoredSparseArray - -# Wraps a block sparse array but replaces the unstored values. -# This is used in printing in order to customize printing -# of zero/unstored values. -struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <: - AbstractBlockSparseArray{T,N} - parent::Parent - getunstoredblock::F -end -Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent -Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a)) -function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray) - return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock) -end - -# This is copied from `SparseArraysBase.jl` since it is not part -# of the public interface. -# Like `Char` but prints without quotes. -struct UnquotedChar <: AbstractChar - char::Char +function SparseArraysBase.isstored( + A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N} +) where {N} + bI = BlockIndex(findblockindex.(axes(A), I)) + bA = blocks(A) + return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...) end -Base.show(io::IO, c::UnquotedChar) = print(io, c.char) -Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c) -using FillArrays: Fill -struct GetUnstoredBlockShow{Axes} - axes::Axes -end -@inline function (f::GetUnstoredBlockShow)( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - # TODO: Make sure this works for sparse or block sparse blocks, immutable - # blocks, diagonal blocks, etc.! - b_size = ntuple(ndims(a)) do d - return length(f.axes[d][Block(I[d])]) +function Base.replace_in_print_matrix( + A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString +) + return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s) +end + +# attempt to catch things that wrap GPU arrays +function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray) + X_cpu = adapt(Array, X) + if typeof(X_cpu) === typeof(X) # prevent infinite recursion + # need to specify ndims to allow specialized code for vector/matrix + @allowscalar @invoke Base.print_array( + io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)} + ) + else + Base.print_array(io, X_cpu) end - return Fill(UnquotedChar('.'), b_size) -end -# TODO: Use `Base.to_indices`. -@inline function (f::GetUnstoredBlockShow)( - a::AbstractArray{<:Any,N}, I::CartesianIndex{N} -) where {N} - return f(a, Tuple(I)...) -end - -# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function -# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`. -function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray) - summary(io, a) - isempty(a) && return nothing - print(io, ":") - println(io) - a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a))) - @allowscalar Base.print_array(io, a′) - return nothing end diff --git a/test/test_basics.jl b/test/test_basics.jl index fe3418c6..da93d228 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1139,7 +1139,7 @@ arrayts = (Array, JLArray) a = BlockSparseMatrix{elt,arrayt{elt,2}}([2, 2], [2, 2]) @allowscalar a[1, 2] = 12 @test sprint(show, "text/plain", a) == - "$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ . .\n $(zero(eltype(a))) $(zero(eltype(a))) │ . .\n ───────────┼──────\n . . │ . .\n . . │ . ." + "$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ " end end @testset "TypeParameterAccessors.position" begin