From 6ff5e9631d7c2d8d733d5af9ea95c82268bd7a38 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 01:15:15 +0200 Subject: [PATCH 1/8] Reworking handshake() to support unit testing. - Split handshake() into separate functions - Created ClientIO interface to enable mocking - Added unit tests for clientInit() --- client.go | 140 ++++++++++++++++++++++++++++++++++--------------- client_test.go | 45 +++++++++++++++- 2 files changed, 141 insertions(+), 44 deletions(-) diff --git a/client.go b/client.go index 91402c9..77a60eb 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,9 @@ -// Package vnc implements a VNC client. -// -// References: -// [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +/* +Package vnc implements a VNC client. + +References: + [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +*/ package vnc import ( @@ -13,9 +15,11 @@ import ( "unicode" ) +// The ClientConn type holds client connection information. type ClientConn struct { c net.Conn config *ClientConfig + rw ClientIO // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since @@ -71,8 +75,25 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { c: c, config: cfg, } + conn.rw = NewClientIOReaderWriter(c, c) - if err := conn.handshake(); err != nil { + if err := conn.protocolVersionHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityResultHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.clientInit(); err != nil { + conn.Close() + return nil, err + } + if err := conn.serverInit(); err != nil { conn.Close() return nil, err } @@ -279,7 +300,13 @@ func (c *ClientConn) SetPixelFormat(format *PixelFormat) error { return nil } -const pvLen = 12 // ProtocolVersion message length. +const ( + pvLen = 12 // ProtocolVersion message length. + + // Supported protocol versions. + PROTO_VERS_UNSUP = "UNSUP" + PROTO_VERS_3_8 = "003.008" +) func parseProtocolVersion(pv []byte) (uint, uint, error) { var major, minor uint @@ -299,42 +326,48 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { return major, minor, nil } -func (c *ClientConn) handshake() error { +// protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake. +func (c *ClientConn) protocolVersionHandshake() error { var protocolVersion [pvLen]byte - // 7.1.1, read the ProtocolVersion message sent by the server. + // Read the ProtocolVersion message sent by the server. if _, err := io.ReadFull(c.c, protocolVersion[:]); err != nil { return err } - maxMajor, maxMinor, err := parseProtocolVersion(protocolVersion[:]) + major, minor, err := parseProtocolVersion(protocolVersion[:]) if err != nil { return err } - if maxMajor < 3 { - return fmt.Errorf("unsupported major version, less than 3: %d", maxMajor) + pv := PROTO_VERS_UNSUP + if major == 3 && minor >= 8 { + pv = PROTO_VERS_3_8 } - if maxMinor < 8 { - return fmt.Errorf("unsupported minor version, less than 8: %d", maxMinor) + if pv == PROTO_VERS_UNSUP { + return fmt.Errorf("unsupported server ProtocolVersion '%v'", string(protocolVersion[:])) } // Respond with the version we will support - if _, err = c.c.Write([]byte("RFB 003.008\n")); err != nil { + if _, err = c.c.Write([]byte("RFB " + pv + "\n")); err != nil { return err } - // 7.1.2 Security Handshake from server + return nil +} + +// securityHandshake implements §7.1.2 Security Handshake. +func (c *ClientConn) securityHandshake() error { var numSecurityTypes uint8 - if err = binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { + + if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { return err } - if numSecurityTypes == 0 { return fmt.Errorf("no security types: %s", c.readErrorReason()) } securityTypes := make([]uint8, numSecurityTypes) - if err = binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { return err } @@ -354,64 +387,65 @@ FindAuth: } } } - if auth == nil { return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes) } // Respond back with the security type we'll use - if err = binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { + if err := binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { return err } - - if err = auth.Handshake(c.c); err != nil { + if err := auth.Handshake(c.c); err != nil { return err } + return nil +} - // 7.1.3 SecurityResult Handshake +// securityResultHandshake implements §7.1.3 SecurityResult Handshake. +func (c *ClientConn) securityResultHandshake() error { var securityResult uint32 - if err = binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } - if securityResult == 1 { return fmt.Errorf("security handshake failed: %s", c.readErrorReason()) } + return nil +} - // 7.3.1 ClientInit - var sharedFlag uint8 = 1 - if c.config.Exclusive { - sharedFlag = 0 +// clientInit implements §7.3.1 ClientInit. +func (c *ClientConn) clientInit() error { + var sharedFlag uint8 + if !c.config.Exclusive { + sharedFlag = 1 } - - if err = binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + //if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + if err := c.rw.Write(sharedFlag); err != nil { return err } + return nil +} - // 7.3.2 ServerInit - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { +// serverInit implements §7.3.2 ServerInit. +func (c *ClientConn) serverInit() error { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { return err } - - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { return err } - - // Read the pixel format - if err = readPixelFormat(c.c, &c.PixelFormat); err != nil { + if err := readPixelFormat(c.c, &c.PixelFormat); err != nil { return err } var nameLength uint32 - if err = binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { return err } - nameBytes := make([]uint8, nameLength) - if err = binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { return err } - c.DesktopName = string(nameBytes) return nil @@ -480,3 +514,25 @@ func (c *ClientConn) readErrorReason() string { return string(reason) } + +type ClientIO interface { + Read(data interface{}) error + Write(data interface{}) error +} + +type ClientIOReaderWriter struct { + reader io.Reader + writer io.Writer +} + +func NewClientIOReaderWriter(r io.Reader, w io.Writer) ClientIOReaderWriter { + return ClientIOReaderWriter{r, w} +} + +func (rw ClientIOReaderWriter) Read(data interface{}) error { + return binary.Read(rw.reader, binary.BigEndian, data) +} + +func (rw ClientIOReaderWriter) Write(data interface{}) error { + return binary.Write(rw.writer, binary.BigEndian, data) +} diff --git a/client_test.go b/client_test.go index 31591b4..c27082d 100644 --- a/client_test.go +++ b/client_test.go @@ -40,7 +40,7 @@ func TestClient_LowMajorVersion(t *testing.T) { t.Fatal("error expected") } - if err.Error() != "unsupported major version, less than 3: 2" { + if err.Error() != "unsupported server ProtocolVersion 'RFB 002.009\n'" { t.Fatalf("unexpected error: %s", err) } } @@ -56,7 +56,7 @@ func TestClient_LowMinorVersion(t *testing.T) { t.Fatal("error expected") } - if err.Error() != "unsupported minor version, less than 8: 7" { + if err.Error() != "unsupported server ProtocolVersion 'RFB 003.007\n'" { t.Fatalf("unexpected error: %s", err) } } @@ -93,3 +93,44 @@ func TestParseProtocolVersion(t *testing.T) { } } } + +type MockClientIOReaderWriter struct { + i, o interface{} +} + +func (rw *MockClientIOReaderWriter) Read(data interface{}) error { + rw.i = data + return nil +} + +func (rw *MockClientIOReaderWriter) Write(data interface{}) error { + rw.o = data + return nil +} + +func TestClientInit(t *testing.T) { + var err error + + tests := []struct { + exclusive bool + shared uint8 + }{ + {true, 0}, + {false, 1}, + } + + rw := &MockClientIOReaderWriter{} + cfg := &ClientConfig{} + conn := &ClientConn{config: cfg, rw: rw} + + for _, tt := range tests { + cfg.Exclusive = tt.exclusive + err = conn.clientInit() + if err != nil { + t.Fatalf("clientInit() error %v", err) + } + if rw.o != uint8(tt.shared) { + t.Errorf("clientInit() got = %v, want %v", rw.o, tt.shared) + } + } +} From 9fd5d84baa5a41af30f63101b82bf6860aeb1959 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 01:23:02 +0200 Subject: [PATCH 2/8] moving the client protocol version const lower --- client.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 77a60eb..7c7cfa7 100644 --- a/client.go +++ b/client.go @@ -300,13 +300,7 @@ func (c *ClientConn) SetPixelFormat(format *PixelFormat) error { return nil } -const ( - pvLen = 12 // ProtocolVersion message length. - - // Supported protocol versions. - PROTO_VERS_UNSUP = "UNSUP" - PROTO_VERS_3_8 = "003.008" -) +const pvLen = 12 // ProtocolVersion message length. func parseProtocolVersion(pv []byte) (uint, uint, error) { var major, minor uint @@ -326,6 +320,12 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { return major, minor, nil } +const ( + // Client ProtocolVersions. + PROTO_VERS_UNSUP = "UNSUPPORTED" + PROTO_VERS_3_8 = "RFB 003.008\n" +) + // protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake. func (c *ClientConn) protocolVersionHandshake() error { var protocolVersion [pvLen]byte @@ -348,7 +348,7 @@ func (c *ClientConn) protocolVersionHandshake() error { } // Respond with the version we will support - if _, err = c.c.Write([]byte("RFB " + pv + "\n")); err != nil { + if _, err = c.c.Write([]byte(pv)); err != nil { return err } From a56035f1aef7205c0b1f87dbd1ddd99ebe1f1489 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 16:18:23 +0200 Subject: [PATCH 3/8] Reworked and simplified mocking, and now test securityResultHandshake(). --- client.go | 59 ++++++++++++++++---------------- client_test.go | 92 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 102 insertions(+), 49 deletions(-) diff --git a/client.go b/client.go index 7c7cfa7..c14a8ca 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,6 @@ import ( type ClientConn struct { c net.Conn config *ClientConfig - rw ClientIO // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since @@ -75,7 +74,6 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { c: c, config: cfg, } - conn.rw = NewClientIOReaderWriter(c, c) if err := conn.protocolVersionHandshake(); err != nil { conn.Close() @@ -89,7 +87,7 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn.Close() return nil, err } - if err := conn.clientInit(); err != nil { + if _, err := conn.clientInit(); err != nil { conn.Close() return nil, err } @@ -363,7 +361,11 @@ func (c *ClientConn) securityHandshake() error { return err } if numSecurityTypes == 0 { - return fmt.Errorf("no security types: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return fmt.Errorf("no security types: %s", reason) } securityTypes := make([]uint8, numSecurityTypes) @@ -404,26 +406,33 @@ FindAuth: // securityResultHandshake implements §7.1.3 SecurityResult Handshake. func (c *ClientConn) securityResultHandshake() error { var securityResult uint32 + if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } if securityResult == 1 { - return fmt.Errorf("security handshake failed: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("SecurityResult handshake failed: %s", reason)) } + return nil } // clientInit implements §7.3.1 ClientInit. -func (c *ClientConn) clientInit() error { +func (c *ClientConn) clientInit() (uint8, error) { var sharedFlag uint8 + if !c.config.Exclusive { sharedFlag = 1 } - //if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { - if err := c.rw.Write(sharedFlag); err != nil { - return err + if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + return 0, err } - return nil + + return sharedFlag, nil } // serverInit implements §7.3.2 ServerInit. @@ -501,38 +510,28 @@ func (c *ClientConn) mainLoop() { } } -func (c *ClientConn) readErrorReason() string { +func (c *ClientConn) readErrorReason() (string, error) { var reasonLen uint32 if err := binary.Read(c.c, binary.BigEndian, &reasonLen); err != nil { - return "" + return "", err } reason := make([]uint8, reasonLen) if err := binary.Read(c.c, binary.BigEndian, &reason); err != nil { - return "" + return "", err } - return string(reason) -} - -type ClientIO interface { - Read(data interface{}) error - Write(data interface{}) error -} - -type ClientIOReaderWriter struct { - reader io.Reader - writer io.Writer + return string(reason), nil } -func NewClientIOReaderWriter(r io.Reader, w io.Writer) ClientIOReaderWriter { - return ClientIOReaderWriter{r, w} +type vncError struct { + s string } -func (rw ClientIOReaderWriter) Read(data interface{}) error { - return binary.Read(rw.reader, binary.BigEndian, data) +func NewVNCError(s string) error { + return &vncError{s} } -func (rw ClientIOReaderWriter) Write(data interface{}) error { - return binary.Write(rw.writer, binary.BigEndian, data) +func (e vncError) Error() string { + return e.s } diff --git a/client_test.go b/client_test.go index c27082d..2603d8a 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,13 @@ package vnc import ( + "bytes" + "encoding/binary" "fmt" "net" + "reflect" "testing" + "time" ) func newMockServer(t *testing.T, version string) string { @@ -94,23 +98,47 @@ func TestParseProtocolVersion(t *testing.T) { } } -type MockClientIOReaderWriter struct { - i, o interface{} -} +func TestSecurityResultHandshake(t *testing.T) { + tests := []struct { + result uint32 + ok bool + reason string + }{ + {0, true, ""}, + {1, false, "SecurityResult error"}, + } -func (rw *MockClientIOReaderWriter) Read(data interface{}) error { - rw.i = data - return nil -} + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } -func (rw *MockClientIOReaderWriter) Write(data interface{}) error { - rw.o = data - return nil + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.result); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + + err := conn.securityResultHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) + } + if err != nil { + if verr, ok := err.(*vncError); !ok { + t.Errorf("securityResultHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + } } func TestClientInit(t *testing.T) { - var err error - tests := []struct { exclusive bool shared uint8 @@ -119,18 +147,44 @@ func TestClientInit(t *testing.T) { {false, 1}, } - rw := &MockClientIOReaderWriter{} - cfg := &ClientConfig{} - conn := &ClientConn{config: cfg, rw: rw} + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } for _, tt := range tests { - cfg.Exclusive = tt.exclusive - err = conn.clientInit() + mockConn.Reset() + conn.config.Exclusive = tt.exclusive + + shared, err := conn.clientInit() if err != nil { t.Fatalf("clientInit() error %v", err) } - if rw.o != uint8(tt.shared) { - t.Errorf("clientInit() got = %v, want %v", rw.o, tt.shared) + if shared != tt.shared { + t.Errorf("clientInit() got = %v, want %v", shared, tt.shared) } } } + +// MockConn implements the net.Conn interface. +type MockConn struct { + b bytes.Buffer +} + +func (m *MockConn) Read(b []byte) (int, error) { + return m.b.Read(b) +} +func (m *MockConn) Write(b []byte) (int, error) { + return m.b.Write(b) +} +func (m *MockConn) Close() error { return nil } +func (m *MockConn) LocalAddr() net.Addr { return nil } +func (m *MockConn) RemoteAddr() net.Addr { return nil } +func (m *MockConn) SetDeadline(t time.Time) error { return nil } +func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } +// Implement additional buffer.Buffer functions. +func (m *MockConn) Reset() { + m.b.Reset() +} From 533f16f8c597412ec98a284c6a85250aa7213e71 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 16:24:31 +0200 Subject: [PATCH 4/8] renamed vncError to VNCError to make it externally visible --- client.go | 8 +++++--- client_test.go | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c14a8ca..9f81973 100644 --- a/client.go +++ b/client.go @@ -524,14 +524,16 @@ func (c *ClientConn) readErrorReason() (string, error) { return string(reason), nil } -type vncError struct { +// VNCError implements error interface. +type VNCError struct { s string } +// NewVNCError returns a custom VNCError error. func NewVNCError(s string) error { - return &vncError{s} + return &VNCError{s} } -func (e vncError) Error() string { +func (e VNCError) Error() string { return e.s } diff --git a/client_test.go b/client_test.go index 2603d8a..6f6cec3 100644 --- a/client_test.go +++ b/client_test.go @@ -131,7 +131,7 @@ func TestSecurityResultHandshake(t *testing.T) { t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) } if err != nil { - if verr, ok := err.(*vncError); !ok { + if verr, ok := err.(*VNCError); !ok { t.Errorf("securityResultHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) } } From d1f6b14a56b4c46ff31dfa9408cb36ed2a291de1 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 17:13:38 +0200 Subject: [PATCH 5/8] added unit test for serverInit() --- client_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/client_test.go b/client_test.go index 6f6cec3..1c219ae 100644 --- a/client_test.go +++ b/client_test.go @@ -167,6 +167,56 @@ func TestClientInit(t *testing.T) { } } +func TestServerInit(t *testing.T) { + tests := []struct { + fbWidth, fbHeight uint16 + pixelFormat [16]byte // TODO(kward): replace with PixelFormat + desktopName string + }{ + {100, 200, [16]byte{}, "foo"}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { + t.Fatal(err) + } + + err := conn.serverInit() + if err != nil { + t.Fatalf("serverInit() error %v", err) + } + if conn.FrameBufferWidth != tt.fbWidth { + t.Errorf("serverInit() FrameBufferWidth: got = %v, want = %v", conn.FrameBufferWidth, tt.fbWidth) + } + if conn.FrameBufferHeight != tt.fbHeight { + t.Errorf("serverInit() FrameBufferHeight: got = %v, want = %v", conn.FrameBufferHeight, tt.fbHeight) + } + // TODO(kward): add test for PixelFormat. + if conn.DesktopName != tt.desktopName { + t.Errorf("serverInit() DesktopName: got = %v, want = %v", conn.DesktopName, tt.desktopName) + } + } +} + // MockConn implements the net.Conn interface. type MockConn struct { b bytes.Buffer @@ -184,6 +234,7 @@ func (m *MockConn) RemoteAddr() net.Addr { return nil } func (m *MockConn) SetDeadline(t time.Time) error { return nil } func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } + // Implement additional buffer.Buffer functions. func (m *MockConn) Reset() { m.b.Reset() From 306a2489e6efb54ec9f45fbf8778a4fb4aca891a Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 17:37:15 +0200 Subject: [PATCH 6/8] add invalid protocol tests to TestServerInit() --- client_test.go | 53 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/client_test.go b/client_test.go index 1c219ae..62293ea 100644 --- a/client_test.go +++ b/client_test.go @@ -168,12 +168,26 @@ func TestClientInit(t *testing.T) { } func TestServerInit(t *testing.T) { + const ( + none = iota + fbw + fbh + pf + dn + ) tests := []struct { + eof int fbWidth, fbHeight uint16 pixelFormat [16]byte // TODO(kward): replace with PixelFormat desktopName string }{ - {100, 200, [16]byte{}, "foo"}, + // Valid protocol. + {dn, 100, 200, [16]byte{}, "foo"}, + // Invalid protocol (missing fields). + {eof: none}, + {eof: fbw, fbWidth: 1}, + {eof: fbh, fbWidth: 2, fbHeight: 1}, + {eof: pf, fbWidth: 3, fbHeight: 2, pixelFormat: [16]byte{}}, } mockConn := &MockConn{} @@ -184,23 +198,38 @@ func TestServerInit(t *testing.T) { for _, tt := range tests { mockConn.Reset() - if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { - t.Fatal(err) - } - if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { - t.Fatal(err) + if tt.eof >= fbw { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { + t.Fatal(err) + } } - if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { - t.Fatal(err) + if tt.eof >= fbh { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { + t.Fatal(err) + } } - if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { - t.Fatal(err) + if tt.eof >= pf { + if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { + t.Fatal(err) + } } - if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { - t.Fatal(err) + if tt.eof >= dn { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { + t.Fatal(err) + } } err := conn.serverInit() + if tt.eof < dn && err == nil { + t.Fatalf("serverInit() expected error") + } + if tt.eof < dn { + // If the protocol was incomplete, there is no point in checking values. + continue + } if err != nil { t.Fatalf("serverInit() error %v", err) } From f190d8e05217c4053181497c1b41badaa0cb11f2 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 22:09:40 +0200 Subject: [PATCH 7/8] added constants for standard security types --- client_auth.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/client_auth.go b/client_auth.go index c88f911..c6dd037 100644 --- a/client_auth.go +++ b/client_auth.go @@ -4,6 +4,12 @@ import ( "net" ) +const ( + secTypeInvalid = iota + secTypeNone + secTypeVNCAuth +) + // A ClientAuth implements a method of authenticating with a remote server. type ClientAuth interface { // SecurityType returns the byte identifier sent by the server to @@ -16,10 +22,10 @@ type ClientAuth interface { } // ClientAuthNone is the "none" authentication. See 7.1.2 -type ClientAuthNone byte +type ClientAuthNone struct{} func (*ClientAuthNone) SecurityType() uint8 { - return 1 + return secTypeNone } func (*ClientAuthNone) Handshake(net.Conn) error { From 2f17980fb9f01043e38e1828fa848e657317303f Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 22:10:18 +0200 Subject: [PATCH 8/8] added tests for protocolVersionHandshake() and securityHandshake() --- client.go | 15 +++--- client_test.go | 137 +++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 134 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index 9f81973..8352bf0 100644 --- a/client.go +++ b/client.go @@ -87,7 +87,7 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn.Close() return nil, err } - if _, err := conn.clientInit(); err != nil { + if err := conn.clientInit(); err != nil { conn.Close() return nil, err } @@ -342,7 +342,7 @@ func (c *ClientConn) protocolVersionHandshake() error { pv = PROTO_VERS_3_8 } if pv == PROTO_VERS_UNSUP { - return fmt.Errorf("unsupported server ProtocolVersion '%v'", string(protocolVersion[:])) + return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:]))) } // Respond with the version we will support @@ -356,7 +356,6 @@ func (c *ClientConn) protocolVersionHandshake() error { // securityHandshake implements §7.1.2 Security Handshake. func (c *ClientConn) securityHandshake() error { var numSecurityTypes uint8 - if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { return err } @@ -365,7 +364,7 @@ func (c *ClientConn) securityHandshake() error { if err != nil { return err } - return fmt.Errorf("no security types: %s", reason) + return NewVNCError(fmt.Sprintf("Security handshake failed; no security types: %v", reason)) } securityTypes := make([]uint8, numSecurityTypes) @@ -390,7 +389,7 @@ FindAuth: } } if auth == nil { - return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes) + return NewVNCError(fmt.Sprintf("Security handshake failed; no suitable auth schemes found; server supports: %#v", securityTypes)) } // Respond back with the security type we'll use @@ -422,17 +421,17 @@ func (c *ClientConn) securityResultHandshake() error { } // clientInit implements §7.3.1 ClientInit. -func (c *ClientConn) clientInit() (uint8, error) { +func (c *ClientConn) clientInit() error { var sharedFlag uint8 if !c.config.Exclusive { sharedFlag = 1 } if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { - return 0, err + return err } - return sharedFlag, nil + return nil } // serverInit implements §7.3.2 ServerInit. diff --git a/client_test.go b/client_test.go index 62293ea..0faa55a 100644 --- a/client_test.go +++ b/client_test.go @@ -43,9 +43,10 @@ func TestClient_LowMajorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported server ProtocolVersion 'RFB 002.009\n'" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -59,9 +60,10 @@ func TestClient_LowMinorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported server ProtocolVersion 'RFB 003.007\n'" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -98,6 +100,116 @@ func TestParseProtocolVersion(t *testing.T) { } } +func TestProtocolVersionHandshake(t *testing.T) { + tests := []struct { + server string + client string + ok bool + }{ + // Supported versions. + {"RFB 003.008\n", "RFB 003.008\n", true}, + {"RFB 003.389\n", "RFB 003.008\n", true}, + // Unsupported versions. + {server: "RFB 003.003\n", ok: false}, + {server: "RFB 002.009\n", ok: false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + + // Validate server message. + err := conn.protocolVersionHandshake() + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() expected error for server protocol version %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("protocolVersionHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var client [pvLen]byte + err = binary.Read(conn.c, binary.BigEndian, &client) + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() unexpected error: %v", err) + } + if string(client[:]) != tt.client && tt.ok { + t.Errorf("protocolVersionHandshake() client version: got = %v, want = %v", string(client[:]), tt.client) + } + } +} + +func TestSecurityHandshake(t *testing.T) { + tests := []struct { + server []uint8 + client []ClientAuth + secType uint8 + ok bool + }{ + //-- Supported security types. -- + // Both server and client support the None security type. + {[]uint8{secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + // Server supports None and VNCAuth, client supports only None. + {[]uint8{secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + //-- Unsupported security types. -- + // Server provided no security types. + {[]uint8{}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + // Client and server don't support same security types. + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if len(tt.server) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint8(len(tt.server))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + } + + // Validate server message. + conn.config.Auth = tt.client + err := conn.securityHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) + } + if len(tt.server) == 0 { + // The protocol was incomplete; no point in checking values. + continue + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var secType uint8 + err = binary.Read(conn.c, binary.BigEndian, &secType) + if secType != tt.secType { + t.Errorf("securityHandshake() secType: got = %v, want = %v", secType, tt.secType) + } + } +} + func TestSecurityResultHandshake(t *testing.T) { tests := []struct { result uint32 @@ -126,6 +238,7 @@ func TestSecurityResultHandshake(t *testing.T) { t.Fatal(err) } + // Validate server message. err := conn.securityResultHandshake() if err == nil && !tt.ok { t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) @@ -157,12 +270,15 @@ func TestClientInit(t *testing.T) { mockConn.Reset() conn.config.Exclusive = tt.exclusive - shared, err := conn.clientInit() + // Validate client response. + err := conn.clientInit() if err != nil { - t.Fatalf("clientInit() error %v", err) + t.Fatalf("clientInit() unexpected error %v", err) } + var shared uint8 + err = binary.Read(conn.c, binary.BigEndian, &shared) if shared != tt.shared { - t.Errorf("clientInit() got = %v, want %v", shared, tt.shared) + t.Errorf("clientInit() shared: got = %v, want = %v", shared, tt.shared) } } } @@ -222,12 +338,13 @@ func TestServerInit(t *testing.T) { } } + // Validate server message. err := conn.serverInit() if tt.eof < dn && err == nil { t.Fatalf("serverInit() expected error") } if tt.eof < dn { - // If the protocol was incomplete, there is no point in checking values. + // The protocol was incomplete; no point in checking values. continue } if err != nil {