diff --git a/conn.go b/conn.go index c5371cf6..8c55d9e1 100644 --- a/conn.go +++ b/conn.go @@ -69,6 +69,12 @@ func (s transactionStatus) String() string { panic("not reached") } +// Any failure in communication should be raised as a communicationError to +// allow errRecover to know what happened. +type communicationError struct { + error +} + type conn struct { c net.Conn buf *bufio.Reader @@ -153,7 +159,7 @@ func Open(name string) (_ driver.Conn, err error) { c, err := net.Dial(network(o)) if err != nil { - return nil, err + return nil, communicationError{err} } cn := &conn{c: c} @@ -556,7 +562,7 @@ func (cn *conn) send(m *writeBuf) { _, err := cn.c.Write(*m) if err != nil { - panic(err) + panic(communicationError{err}) } } @@ -574,7 +580,7 @@ func (cn *conn) recvMessage() (byte, *readBuf, error) { x := cn.scratch[:5] _, err := io.ReadFull(cn.buf, x) if err != nil { - return 0, nil, err + return 0, nil, communicationError{err} } t := x[0] @@ -588,7 +594,7 @@ func (cn *conn) recvMessage() (byte, *readBuf, error) { } _, err = io.ReadFull(cn.buf, y) if err != nil { - return 0, nil, err + return 0, nil, communicationError{err} } return t, (*readBuf)(&y), nil diff --git a/conn_test.go b/conn_test.go index 2797b1b2..adb655bb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -412,7 +412,7 @@ func TestBadConn(t *testing.T) { func() { defer errRecover(&err) - panic(io.EOF) + panic(communicationError{io.EOF}) }() if err != driver.ErrBadConn { diff --git a/error.go b/error.go index 1f7e9e79..c1346601 100644 --- a/error.go +++ b/error.go @@ -3,8 +3,6 @@ package pq import ( "database/sql/driver" "fmt" - "io" - "net" "runtime" ) @@ -452,14 +450,10 @@ func errRecover(err *error) { } else { *err = v } - case *net.OpError: + case communicationError: *err = driver.ErrBadConn case error: - if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { - *err = driver.ErrBadConn - } else { - *err = v - } + *err = v default: panic(fmt.Sprintf("unknown error: %#v", e))