diff --git a/conn.go b/conn.go index da4ff9de..9f976f66 100644 --- a/conn.go +++ b/conn.go @@ -828,6 +828,7 @@ func decideColumnFormats( } func (cn *conn) prepareTo(q, stmtName string) *stmt { + var err error st := &stmt{cn: cn, name: stmtName} b := cn.writeBuf('P') @@ -839,13 +840,30 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { b.byte('S') b.string(st.name) - b.next('S') + if stmtName != "" { + b.next('S') // sync + } else { + b.next('H') // flush + } cn.send(b) - cn.readParseResponse() - st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() + if err := cn.readParseResponse(); err != nil { + cn.send(cn.writeBuf('S')) // sync + cn.readReadyForQuery() + panic(err) + } + + st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse() + if err != nil { + cn.send(cn.writeBuf('S')) // sync + cn.readReadyForQuery() + panic(err) + } st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) - cn.readReadyForQuery() + if stmtName != "" { + cn.readReadyForQuery() + } + return st } @@ -906,7 +924,11 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.binaryParameters { cn.sendBinaryModeQuery(query, args) - cn.readParseResponse() + if err := cn.readParseResponse(); err != nil { + cn.readReadyForQuery() + panic(err) + } + cn.readBindResponse() rows := &rows{cn: cn} rows.rowsHeader = cn.readPortalDescribeResponse() @@ -939,7 +961,11 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err if cn.binaryParameters { cn.sendBinaryModeQuery(query, args) - cn.readParseResponse() + if err := cn.readParseResponse(); err != nil { + cn.readReadyForQuery() + panic(err) + } + cn.readBindResponse() cn.readPortalDescribeResponse() cn.postExecuteWorkaround() @@ -1819,25 +1845,25 @@ func (cn *conn) processBackendKeyData(r *readBuf) { cn.secretKey = r.int32() } -func (cn *conn) readParseResponse() { +func (cn *conn) readParseResponse() (err error) { t, r := cn.recv1() switch t { case '1': - return case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + err = parseError(r) default: cn.err.set(driver.ErrBadConn) errorf("unexpected Parse response %q", t) } + + return } func (cn *conn) readStatementDescribeResponse() ( paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, + err error, ) { for { t, r := cn.recv1() @@ -1849,14 +1875,13 @@ func (cn *conn) readStatementDescribeResponse() ( paramTyps[i] = r.oid() } case 'n': - return paramTyps, nil, nil + return case 'T': colNames, colTyps = parseStatementRowDescribe(r) - return paramTyps, colNames, colTyps + return case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + err = parseError(r) + return default: cn.err.set(driver.ErrBadConn) errorf("unexpected Describe statement response %q", t)