diff --git a/ssh/channel.go b/ssh/channel.go index c0834c00df..5fdc69d7ed 100644 --- a/ssh/channel.go +++ b/ssh/channel.go @@ -212,7 +212,7 @@ func (ch *channel) writePacket(packet []byte) error { return io.EOF } ch.sentClose = (packet[0] == msgChannelClose) - err := ch.mux.conn.writePacket(packet) + err := ch.mux.conn.WritePacket(packet) ch.writeMu.Unlock() return err } diff --git a/ssh/client.go b/ssh/client.go index 7b00bff1ca..dd7fb3ade1 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -88,6 +88,13 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil } +func NewClientConnFromTransport(t Transport) (Conn, <-chan NewChannel, <-chan *Request, error) { + conn := &connection{ + mux: newMux(t), + } + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + // clientHandshake performs the client side key exchange. See RFC 4253 Section // 7. func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { diff --git a/ssh/handshake.go b/ssh/handshake.go index 2b10b05a49..db1d09669f 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -183,6 +183,10 @@ func (t *handshakeTransport) printPacket(p []byte, write bool) { } } +func (t *handshakeTransport) ReadPacket() ([]byte, error) { + return t.readPacket() +} + func (t *handshakeTransport) readPacket() ([]byte, error) { p, ok := <-t.incoming if !ok { @@ -479,6 +483,10 @@ func (t *handshakeTransport) sendKexInit() error { return nil } +func (t *handshakeTransport) WritePacket(p []byte) error { + return t.writePacket(p) +} + func (t *handshakeTransport) writePacket(p []byte) error { switch p[0] { case msgKexInit: diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go index 8697cd6140..25b6df9f18 100644 --- a/ssh/mempipe_test.go +++ b/ssh/mempipe_test.go @@ -20,6 +20,10 @@ type memTransport struct { *sync.Cond } +func (t *memTransport) ReadPacket() ([]byte, error) { + return t.readPacket() +} + func (t *memTransport) readPacket() ([]byte, error) { t.Lock() defer t.Unlock() @@ -53,6 +57,10 @@ func (t *memTransport) Close() error { return err } +func (t *memTransport) WritePacket(p []byte) error { + return t.writePacket(p) +} + func (t *memTransport) writePacket(p []byte) error { t.write.Lock() defer t.write.Unlock() @@ -66,7 +74,7 @@ func (t *memTransport) writePacket(p []byte) error { return nil } -func memPipe() (a, b packetConn) { +func memPipe() (a, b *memTransport) { t1 := memTransport{} t2 := memTransport{} t1.write = &t2 diff --git a/ssh/mux.go b/ssh/mux.go index f19016270e..ceb70d77aa 100644 --- a/ssh/mux.go +++ b/ssh/mux.go @@ -86,7 +86,7 @@ func (c *chanList) dropAll() []*channel { // mux represents the state for the SSH connection protocol, which // multiplexes many channels onto a single packet transport. type mux struct { - conn packetConn + conn Transport chanList chanList incomingChannels chan NewChannel @@ -113,7 +113,7 @@ func (m *mux) Wait() error { } // newMux returns a mux that runs over the given connection. -func newMux(p packetConn) *mux { +func newMux(p Transport) *mux { m := &mux{ conn: p, incomingChannels: make(chan NewChannel, chanSize), @@ -134,7 +134,7 @@ func (m *mux) sendMessage(msg interface{}) error { if debugMux { log.Printf("send global(%d): %#v", m.chanList.offset, msg) } - return m.conn.writePacket(p) + return m.conn.WritePacket(p) } func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { @@ -212,7 +212,7 @@ func (m *mux) loop() { // onePacket reads and processes one packet. func (m *mux) onePacket() error { - packet, err := m.conn.readPacket() + packet, err := m.conn.ReadPacket() if err != nil { return err } diff --git a/ssh/mux_test.go b/ssh/mux_test.go index 94596ec2d7..7fdda672a0 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go @@ -154,7 +154,7 @@ func TestMuxChannelOverflow(t *testing.T) { marshalUint32(packet[5:], uint32(1)) packet[9] = 42 - if err := writer.mux.conn.writePacket(packet); err != nil { + if err := writer.mux.conn.WritePacket(packet); err != nil { t.Errorf("could not send packet") } if _, err := reader.SendRequest("hello", true, nil); err == nil { @@ -432,7 +432,7 @@ func TestMuxInvalidRecord(t *testing.T) { marshalUint32(packet[5:], 1) packet[9] = 42 - a.conn.writePacket(packet) + a.conn.WritePacket(packet) go a.SendRequest("hello", false, nil) // 'a' wrote an invalid packet, so 'b' has exited. req, ok := <-b.incomingRequests @@ -475,7 +475,7 @@ func TestMuxMaxPacketSize(t *testing.T) { marshalUint32(packet[5:], uint32(len(large))) packet[9] = 42 - if err := a.mux.conn.writePacket(packet); err != nil { + if err := a.mux.conn.WritePacket(packet); err != nil { t.Errorf("could not send packet") } diff --git a/ssh/transport.go b/ssh/transport.go index 49ddc2e7de..47928d4542 100644 --- a/ssh/transport.go +++ b/ssh/transport.go @@ -37,6 +37,21 @@ type packetConn interface { Close() error } +// Transport represents a connection that implements packet based operations as +// specified by SSH Transport Protocol (RFC 4253). +type Transport interface { + // WritePacket encrypts and sends a packet of data to the remote peer. + WritePacket([]byte) error + + // ReadPacket reads and decrypts a packet of data from the remote peer. The + // read is blocking. If error is nil then the returned byte slice is always + // non-empty. + ReadPacket() ([]byte, error) + + // Close closes the connection with the remote peer. + Close() error +} + // transport is the keyingTransport that implements the SSH packet // protocol. type transport struct {