Skip to content

Simplify show implementation #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.25"
version = "0.2.26"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
77 changes: 23 additions & 54 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Adapt: Adapt, WrappedArray
using Adapt: Adapt, WrappedArray, adapt
using ArrayLayouts: zero!
using BlockArrays:
BlockArrays,
Expand Down Expand Up @@ -337,60 +337,29 @@
return Array{eltype(a)}(a)
end

using SparseArraysBase: ReplacedUnstoredSparseArray

# Wraps a block sparse array but replaces the unstored values.
# This is used in printing in order to customize printing
# of zero/unstored values.
struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <:
AbstractBlockSparseArray{T,N}
parent::Parent
getunstoredblock::F
end
Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent
Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a))
function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray)
return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock)
end

# This is copied from `SparseArraysBase.jl` since it is not part
# of the public interface.
# Like `Char` but prints without quotes.
struct UnquotedChar <: AbstractChar
char::Char
function SparseArraysBase.isstored(

Check warning on line 340 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L340

Added line #L340 was not covered by tests
A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
bI = BlockIndex(findblockindex.(axes(A), I))
bA = blocks(A)
return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...)

Check warning on line 345 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L343-L345

Added lines #L343 - L345 were not covered by tests
end
Base.show(io::IO, c::UnquotedChar) = print(io, c.char)
Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c)

using FillArrays: Fill
struct GetUnstoredBlockShow{Axes}
axes::Axes
end
@inline function (f::GetUnstoredBlockShow)(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
# TODO: Make sure this works for sparse or block sparse blocks, immutable
# blocks, diagonal blocks, etc.!
b_size = ntuple(ndims(a)) do d
return length(f.axes[d][Block(I[d])])
function Base.replace_in_print_matrix(

Check warning on line 348 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L348

Added line #L348 was not covered by tests
A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString
)
return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s)

Check warning on line 351 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L351

Added line #L351 was not covered by tests
end

# attempt to catch things that wrap GPU arrays
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
X_cpu = adapt(Array, X)
if typeof(X_cpu) === typeof(X) # prevent infinite recursion

Check warning on line 357 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L355-L357

Added lines #L355 - L357 were not covered by tests
# need to specify ndims to allow specialized code for vector/matrix
@allowscalar @invoke Base.print_array(

Check warning on line 359 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L359

Added line #L359 was not covered by tests
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
)
else
Base.print_array(io, X_cpu)

Check warning on line 363 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L363

Added line #L363 was not covered by tests
end
return Fill(UnquotedChar('.'), b_size)
end
# TODO: Use `Base.to_indices`.
@inline function (f::GetUnstoredBlockShow)(
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return f(a, Tuple(I)...)
end

# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function
# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`.
function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray)
summary(io, a)
isempty(a) && return nothing
print(io, ":")
println(io)
a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a)))
@allowscalar Base.print_array(io, a′)
return nothing
end
2 changes: 1 addition & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ arrayts = (Array, JLArray)
a = BlockSparseMatrix{elt,arrayt{elt,2}}([2, 2], [2, 2])
@allowscalar a[1, 2] = 12
@test sprint(show, "text/plain", a) ==
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ . .\n $(zero(eltype(a))) $(zero(eltype(a))) │ . .\n ───────────┼──────\n . .. .\n . .. ."
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ \n ⋅ ⋅ "
end
end
@testset "TypeParameterAccessors.position" begin
Expand Down
Loading