From bb8ad5410bd6bad88575429510ef7cf2388e107e Mon Sep 17 00:00:00 2001 From: Kei Kamikawa Date: Mon, 18 Oct 2021 17:29:41 +0900 Subject: [PATCH] fixed to use context.Context in TLS handshake --- client.go | 16 ++-------------- tls_handshake.go | 21 +++++++++++++++++++++ tls_handshake_116.go | 21 +++++++++++++++++++++ trace.go | 6 ++++-- trace_17.go | 5 +++-- 5 files changed, 51 insertions(+), 18 deletions(-) create mode 100644 tls_handshake.go create mode 100644 tls_handshake_116.go diff --git a/client.go b/client.go index c4b62fbc..4960a91b 100644 --- a/client.go +++ b/client.go @@ -314,9 +314,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h var err error if trace != nil { - err = doHandshakeWithTrace(trace, tlsConn, cfg) + err = doHandshakeWithTrace(ctx, trace, tlsConn, cfg) } else { - err = doHandshake(tlsConn, cfg) + err = doHandshake(ctx, tlsConn, cfg) } if err != nil { @@ -381,15 +381,3 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h netConn = nil // to avoid close in defer. return conn, resp, nil } - -func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { - if err := tlsConn.Handshake(); err != nil { - return err - } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return err - } - } - return nil -} diff --git a/tls_handshake.go b/tls_handshake.go new file mode 100644 index 00000000..a62b68cc --- /dev/null +++ b/tls_handshake.go @@ -0,0 +1,21 @@ +//go:build go1.17 +// +build go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/tls_handshake_116.go b/tls_handshake_116.go new file mode 100644 index 00000000..e1b2b44f --- /dev/null +++ b/tls_handshake_116.go @@ -0,0 +1,21 @@ +//go:build !go1.17 +// +build !go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/trace.go b/trace.go index 834f122a..4244e82f 100644 --- a/trace.go +++ b/trace.go @@ -1,17 +1,19 @@ +//go:build go1.8 // +build go1.8 package websocket import ( + "context" "crypto/tls" "net/http/httptrace" ) -func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { +func doHandshakeWithTrace(ctx context.Context, trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { if trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := doHandshake(tlsConn, cfg) + err := doHandshake(ctx, tlsConn, cfg) if trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) } diff --git a/trace_17.go b/trace_17.go index 77d05a0b..b66fb3b4 100644 --- a/trace_17.go +++ b/trace_17.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package websocket @@ -7,6 +8,6 @@ import ( "net/http/httptrace" ) -func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { - return doHandshake(tlsConn, cfg) +func doHandshakeWithTrace(ctx context.Context, trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + return doHandshake(ctx, tlsConn, cfg) }