Skip to content

Commit c723ba7

Browse files
committed
lapack/netlib: add Dpbtrf and Dpbtrs
Also move the banded matrix conversion code to the lapack/netlib package because that's where any matrix conversion should be done. The lapacke package should accept exactly what LAPACKE accepts which means that unfortunately for banded matrices there will be the inevitable overhead of two conversions: one from BLAS (Gonum) row-major format to LAPACKE row-major format for banded matrices and one inside LAPACKE to the FORTRAN column-major banded format. In case of Dpbtrf the inverse conversion must be performed for the factored matrix. lapack/netlib: add Dpbtrs conv
1 parent d71f404 commit c723ba7

File tree

6 files changed

+276
-161
lines changed

6 files changed

+276
-161
lines changed

lapack/lapacke/internal/conv/conv.go

Lines changed: 0 additions & 19 deletions
This file was deleted.

lapack/lapacke/internal/conv/pb_test.go

Lines changed: 0 additions & 123 deletions
This file was deleted.

lapack/lapacke/internal/conv/pb.go renamed to lapack/netlib/conv.go

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,32 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
package conv
5+
package netlib
66

7-
// DpbToColMajor converts a symmetric band matrix A in CBLAS row-major layout
8-
// to FORTRAN column-major layout and stores the result in B.
7+
import "gonum.org/v1/gonum/blas"
8+
9+
func DpbToRowMajor(uplo byte, n, kd int, a []float64, lda int, b []float64, ldb int) {
10+
if uplo == 'U' {
11+
for j := 0; j < n; j++ {
12+
for ib := max(0, kd-j); ib < kd+1; ib++ {
13+
i := j - kd + ib // Row index in the full matrix
14+
b[i*ldb+kd-ib] = a[ib+j*lda]
15+
}
16+
}
17+
} else {
18+
for j := 0; j < n; j++ {
19+
for ib := 0; ib < min(n-j, kd+1); ib++ {
20+
i := j + ib // Row index in the full matrix
21+
b[i*ldb+kd-ib] = a[ib+j*lda]
22+
}
23+
}
24+
}
25+
}
26+
27+
// convDpbToLapacke converts a symmetric band matrix A in CBLAS row-major layout
28+
// to LAPACKE row-major layout and stores the result in B.
929
//
10-
// For example, when n = 6, kd = 2 and uplo == 'U', DpbToColMajor
11-
// converts
30+
// For example, when n = 6, kd = 2 and uplo == 'U', convDpbToLapacke converts
1231
// A = a00 a01 a02
1332
// a11 a12 a13
1433
// a22 a23 a24
@@ -22,9 +41,9 @@ package conv
2241
// * a01 a12 a23 a34 a45
2342
// a00 a11 a22 a33 a44 a55
2443
// stored in a slice as
25-
// b = [* * a00 * a01 a11 a02 a12 a22 a13 a23 a33 a24 a34 a44 a35 a45 a55]
44+
// b = [* * a02 a13 a24 a35 * a01 a12 a23 a34 a45 a00 a11 a22 a33 a44 a55]
2645
//
27-
// When n = 6, kd = 2 and uplo == 'L', DpbToColMajor converts
46+
// When n = 6, kd = 2 and uplo == 'L', convDpbToLapacke converts
2847
// A = * * a00
2948
// * a10 a11
3049
// a20 a21 a22
@@ -38,43 +57,43 @@ package conv
3857
// a10 a21 a32 a43 a54 *
3958
// a20 a31 a42 a53 * *
4059
// stored in a slice as
41-
// b = [a00 a10 a20 a11 a21 a31 a22 a32 a42 a33 a43 a53 a44 a54 * a55 * *]
60+
// b = [a00 a11 a22 a33 a44 a55 a10 a21 a32 a43 a54 * a20 a31 a42 a53 * * ]
4261
//
4362
// In these example elements marked as * are not referenced.
44-
func DpbToColMajor(uplo byte, n, kd int, a []float64, lda int, b []float64, ldb int) {
45-
if uplo == 'U' {
63+
func convDpbToLapacke(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) {
64+
if uplo == blas.Upper {
4665
for i := 0; i < n; i++ {
4766
for jb := 0; jb < min(n-i, kd+1); jb++ {
4867
j := i + jb // Column index in the full matrix
49-
b[kd-jb+j*ldb] = a[i*lda+jb]
68+
b[(kd-jb)*ldb+j] = a[i*lda+jb]
5069
}
5170
}
5271
} else {
5372
for i := 0; i < n; i++ {
5473
for jb := max(0, kd-i); jb < kd+1; jb++ {
5574
j := i - kd + jb // Column index in the full matrix
56-
b[kd-jb+j*ldb] = a[i*lda+jb]
75+
b[(kd-jb)*ldb+j] = a[i*lda+jb]
5776
}
5877
}
5978
}
6079
}
6180

62-
// DpbToRowMajor converts a symmetric band matrix A in FORTRAN column-major
63-
// layout to CBLAS row-major layout and stores the result in B. In other words,
64-
// it performs the inverse conversion to DpbToColMajor.
65-
func DpbToRowMajor(uplo byte, n, kd int, a []float64, lda int, b []float64, ldb int) {
66-
if uplo == 'U' {
81+
// convDpbToGonum converts a symmetric band matrix A in LAPACKE row-major layout
82+
// to CBLAS row-major layout and stores the result in B. In other words, it
83+
// performs the inverse conversion to DpbToColMajor.
84+
func convDpbToGonum(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) {
85+
if uplo == blas.Upper {
6786
for j := 0; j < n; j++ {
6887
for ib := max(0, kd-j); ib < kd+1; ib++ {
6988
i := j - kd + ib // Row index in the full matrix
70-
b[i*ldb+kd-ib] = a[ib+j*lda]
89+
b[i*ldb+kd-ib] = a[ib*lda+j]
7190
}
7291
}
7392
} else {
7493
for j := 0; j < n; j++ {
7594
for ib := 0; ib < min(n-j, kd+1); ib++ {
7695
i := j + ib // Row index in the full matrix
77-
b[i*ldb+kd-ib] = a[ib+j*lda]
96+
b[i*ldb+kd-ib] = a[ib*lda+j]
7897
}
7998
}
8099
}

lapack/netlib/conv_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Copyright ©2019 The Gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package netlib
6+
7+
import (
8+
"fmt"
9+
"testing"
10+
11+
"golang.org/x/exp/rand"
12+
13+
"gonum.org/v1/gonum/blas"
14+
"gonum.org/v1/gonum/floats"
15+
)
16+
17+
func TestConvDpb(t *testing.T) {
18+
for ti, test := range []struct {
19+
uplo blas.Uplo
20+
n, kd int
21+
a, b []float64
22+
}{
23+
{
24+
uplo: blas.Upper,
25+
n: 6,
26+
kd: 2,
27+
a: []float64{
28+
1, 2, 3, // 1. row
29+
4, 5, 6,
30+
7, 8, 9,
31+
10, 11, 12,
32+
13, 14, -1,
33+
15, -1, -1, // 6. row
34+
},
35+
b: []float64{
36+
-1, -1, 3, 6, 9, 12, // 2. super-diagonal
37+
-1, 2, 5, 8, 11, 14,
38+
1, 4, 7, 10, 13, 15, // main diagonal
39+
},
40+
},
41+
{
42+
uplo: blas.Lower,
43+
n: 6,
44+
kd: 2,
45+
a: []float64{
46+
-1, -1, 1, // 1. row
47+
-1, 2, 3,
48+
4, 5, 6,
49+
7, 8, 9,
50+
10, 11, 12,
51+
13, 14, 15, // 6. row
52+
},
53+
b: []float64{
54+
1, 3, 6, 9, 12, 15, // main diagonal
55+
2, 5, 8, 11, 14, -1,
56+
4, 7, 10, 13, -1, -1, // 2. sub-diagonal
57+
},
58+
},
59+
} {
60+
uplo := test.uplo
61+
n := test.n
62+
kd := test.kd
63+
name := fmt.Sprintf("Case %v (uplo=%c,n=%v,kd=%v)", ti, uplo, n, kd)
64+
65+
a := make([]float64, len(test.a))
66+
copy(a, test.a)
67+
lda := kd + 1
68+
69+
got := make([]float64, len(test.b))
70+
for i := range got {
71+
got[i] = -1
72+
}
73+
ldb := max(1, n)
74+
75+
convDpbToLapacke(uplo, n, kd, a, lda, got, ldb)
76+
if !floats.Equal(test.a, a) {
77+
t.Errorf("%v: unexpected modification of A in conversion to LAPACKE row-major", name)
78+
}
79+
if !floats.Equal(test.b, got) {
80+
t.Errorf("%v: unexpected conversion to LAPACKE row-major;\ngot %v\nwant %v", name, got, test.b)
81+
}
82+
83+
b := make([]float64, len(test.b))
84+
copy(b, test.b)
85+
86+
got = make([]float64, len(test.a))
87+
for i := range got {
88+
got[i] = -1
89+
}
90+
91+
convDpbToGonum(uplo, n, kd, b, ldb, got, lda)
92+
if !floats.Equal(test.b, b) {
93+
t.Errorf("%v: unexpected modification of B in conversion to Gonum row-major", name)
94+
}
95+
if !floats.Equal(test.a, got) {
96+
t.Errorf("%v: unexpected conversion to Gonum row-major;\ngot %v\nwant %v", name, got, test.b)
97+
}
98+
}
99+
100+
rnd := rand.New(rand.NewSource(1))
101+
for _, n := range []int{0, 1, 2, 3, 4, 5, 10} {
102+
for _, kd := range []int{0, (n + 1) / 4, (3*n - 1) / 4, (5*n + 1) / 4} {
103+
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
104+
for _, ldextra := range []int{0, 3} {
105+
name := fmt.Sprintf("uplo=%c,n=%v,kd=%v", uplo, n, kd)
106+
107+
lda := kd + 1 + ldextra
108+
a := make([]float64, n*lda)
109+
for i := range a {
110+
a[i] = rnd.NormFloat64()
111+
}
112+
aCopy := make([]float64, len(a))
113+
copy(aCopy, a)
114+
115+
ldb := max(1, n) + ldextra
116+
b := make([]float64, (kd+1)*ldb)
117+
for i := range b {
118+
b[i] = rnd.NormFloat64()
119+
}
120+
121+
convDpbToLapacke(uplo, n, kd, a, lda, b, ldb)
122+
convDpbToGonum(uplo, n, kd, b, ldb, a, lda)
123+
124+
if !floats.Equal(a, aCopy) {
125+
t.Errorf("%v: conversion does not roundtrip", name)
126+
}
127+
}
128+
}
129+
}
130+
}
131+
}

0 commit comments

Comments
 (0)