diff --git a/ssh/buffer.go b/ssh/buffer.go index 1ab07d0..803b65a 100644 --- a/ssh/buffer.go +++ b/ssh/buffer.go @@ -5,6 +5,7 @@ package ssh import ( + "fmt" "io" "sync" ) @@ -19,7 +20,8 @@ type buffer struct { head *element // the buffer that will be read first tail *element // the buffer that will be read last - closed bool + closed bool + readAborted bool // set to true if a read was aborted } // An element represents a single link in a linked list. @@ -65,6 +67,7 @@ func (b *buffer) Read(buf []byte) (n int, err error) { b.Cond.L.Lock() defer b.Cond.L.Unlock() + b.readAborted = false for len(buf) > 0 { // if there is data in b.head, copy it if len(b.head.buf) > 0 { @@ -92,6 +95,21 @@ func (b *buffer) Read(buf []byte) (n int, err error) { } // out of buffers, wait for producer b.Cond.Wait() + + // we were signaled because the read was aborted, so + // return 0 bytes + if b.readAborted { + fmt.Printf("ssh.buffer: Read: I was aborted\n") + return + } } return } + +func (b *buffer) abortRead() { + fmt.Printf("ssh.buffer: abortRead: aborting any reads in progress\n") + b.Cond.L.Lock() + b.readAborted = true + b.Cond.Signal() + b.Cond.L.Unlock() +} diff --git a/ssh/channel.go b/ssh/channel.go index cc0bb7a..4fe57f7 100644 --- a/ssh/channel.go +++ b/ssh/channel.go @@ -11,6 +11,7 @@ import ( "io" "log" "sync" + "time" ) const ( @@ -203,6 +204,11 @@ type channel struct { // packetPool has a buffer for each extended channel ID to // save allocations during writes. packetPool map[uint32][]byte + + readDeadline time.Time + readDeadlineTimer *time.Timer + readDeadlineMu sync.Mutex + readDeadlineTimerClose chan struct{} } // writePacket sends a packet. If the packet is a channel close, it updates @@ -356,6 +362,11 @@ func (c *channel) adjustWindow(adj uint32) error { } func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + // Check read readline + if !c.readDeadline.IsZero() && time.Until(c.readDeadline) <= 0 { + return 0, nil + } + switch extended { case 1: n, err = c.extPending.Read(data) @@ -379,6 +390,60 @@ func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) return n, err } +// SetDeadline sets the read deadline for future Read calls +// and any currently-blocked Read call. +// If there is a read in progress it is aborted +// and the read will return 0 bytes. +// A zero value for t means Read will not time out. +func (c *channel) SetReadDeadline(deadline time.Time) error { + c.readDeadlineMu.Lock() + if c.readDeadlineTimer != nil { + if c.readDeadlineTimer.Stop() { + c.readDeadlineTimer = nil + } else { + // readDeadlineTimer fired, but we acquired the lock before the callback + // handler did. Wait for callback to acquire it before we start a new timer. + c.readDeadlineTimerClose = make(chan struct{}) + c.readDeadlineMu.Unlock() + <-c.readDeadlineTimerClose + c.readDeadlineMu.Lock() + } + } + c.readDeadlineMu.Unlock() + + c.readDeadline = deadline + if c.readDeadline.IsZero() { + return nil + } + + left := time.Until(c.readDeadline) + if left <= 0 { + fmt.Printf("ssh.channel: SetReadDeadline: read deadline is in the past. Aborting any in-progress reads\n") + + c.pending.abortRead() + c.extPending.abortRead() + } + + c.readDeadlineMu.Lock() + c.readDeadlineTimer = time.AfterFunc(left, c.readDeadlineReached) + c.readDeadlineMu.Unlock() + + return nil +} + +func (c *channel) readDeadlineReached() { + c.readDeadlineMu.Lock() + defer c.readDeadlineMu.Unlock() + if c.readDeadlineTimerClose != nil { + close(c.readDeadlineTimerClose) + c.readDeadlineTimer, c.readDeadlineTimerClose = nil, nil + return + } + c.pending.abortRead() + c.extPending.abortRead() + c.readDeadlineTimer = nil +} + func (c *channel) close() { c.pending.eof() c.extPending.eof() diff --git a/ssh/tcpip.go b/ssh/tcpip.go index ef5059a..c851003 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -486,10 +486,9 @@ func (t *chanConn) RemoteAddr() net.Addr { // SetDeadline sets the read and write deadlines associated // with the connection. func (t *chanConn) SetDeadline(deadline time.Time) error { - if err := t.SetReadDeadline(deadline); err != nil { - return err - } - return t.SetWriteDeadline(deadline) + // for compatibility with previous version, + // the error message contains "tcpChan" + return errors.New("ssh: tcpChan: deadline not supported") } // SetReadDeadline sets the read deadline. @@ -497,11 +496,21 @@ func (t *chanConn) SetDeadline(deadline time.Time) error { // After the deadline, the error from Read will implement net.Error // with Timeout() == true. func (t *chanConn) SetReadDeadline(deadline time.Time) error { + fmt.Printf("ssh.chanConn: SetReadDeadline: checking if Channel has SetReadDeadline\n") + if dl, ok := t.Channel.(setReadDeadliner); ok { + fmt.Printf("ssh.chanConn: SetReadDeadline: yep, it has it\n") + return dl.SetReadDeadline(deadline) + } + // for compatibility with previous version, // the error message contains "tcpChan" return errors.New("ssh: tcpChan: deadline not supported") } +type setReadDeadliner interface { + SetReadDeadline(deadline time.Time) error +} + // SetWriteDeadline exists to satisfy the net.Conn interface // but is not implemented by this type. It always returns an error. func (t *chanConn) SetWriteDeadline(deadline time.Time) error {