diff --git a/Project.toml b/Project.toml index 7d519912..2fada9b9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "2.5.1" +version = "2.5.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/coloring/matrix2graph.jl b/src/coloring/matrix2graph.jl index ef19ca19..fcf4050b 100644 --- a/src/coloring/matrix2graph.jl +++ b/src/coloring/matrix2graph.jl @@ -16,8 +16,7 @@ end """ _rows_by_cols(rows_index,cols_index) -Returns a vector of columns where each column contains -a vector of its row indices. +Returns a vector of columns where each column contains a vector of its row indices. """ function _rows_by_cols(rows_index, cols_index) return _cols_by_rows(cols_index, rows_index) @@ -26,15 +25,14 @@ end """ matrix2graph(sparse_matrix, [partition_by_rows::Bool=true]) -A utility function to generate a graph from input -sparse matrix, columns are represented with vertices -and 2 vertices are connected with an edge only if -the two columns are mutually orthogonal. +A utility function to generate a graph from input sparse matrix, columns are represented +with vertices and 2 vertices are connected with an edge only if the two columns are mutually +orthogonal. -Note that the sparsity pattern is defined by structural nonzeroes, ie includes -explicitly stored zeros. +Note that the sparsity pattern is defined by structural nonzeroes, ie includes explicitly +stored zeros. """ -function matrix2graph(sparse_matrix::SparseMatrixCSC{<:Number, Int}, +function matrix2graph(sparse_matrix::AbstractSparseMatrix{<:Number}, partition_by_rows::Bool = true) (rows_index, cols_index, _) = findnz(sparse_matrix) @@ -43,7 +41,7 @@ function matrix2graph(sparse_matrix::SparseMatrixCSC{<:Number, Int}, num_vtx = partition_by_rows ? nrows : ncols - inner = SimpleGraph(num_vtx) + inner = SimpleGraph{promote_type(eltype(rows_index), eltype(cols_index))}(num_vtx) if partition_by_rows rows_by_cols = _rows_by_cols(rows_index, cols_index)