diff --git a/go.mod b/go.mod index 63d3684..651a39a 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.12 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d - golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 // indirect golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 // indirect ) diff --git a/go.sum b/go.sum index a4132b0..aa2d0eb 100644 --- a/go.sum +++ b/go.sum @@ -1,28 +1,15 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d h1:3qF+Z8Hkrw9sOhrFHti9TlB1Hkac1x+DNRkv0XQiFjo= golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY= -golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= -golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 h1:UiNENfZ8gDvpiWw7IpOMQ27spWmThO1RwwdQVbJahJM= golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc= golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pty.go b/pty.go new file mode 100644 index 0000000..6d54a20 --- /dev/null +++ b/pty.go @@ -0,0 +1,57 @@ +package ssh + +import ( + "bytes" + "io" +) + +// NewPtyWriter creates a writer that handles when the session has a active +// PTY, replacing the \n with \r\n. +func NewPtyWriter(w io.Writer) io.Writer { + return ptyWriter{ + w: w, + } +} + +var _ io.Writer = ptyWriter{} + +type ptyWriter struct { + w io.Writer +} + +func (w ptyWriter) Write(p []byte) (int, error) { + m := len(p) + // normalize \n to \r\n when pty is accepted. + // this is a hardcoded shortcut since we don't support terminal modes. + p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) + p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) + n, err := w.w.Write(p) + if n > m { + n = m + } + return n, err +} + +// NewPtyReadWriter return an io.ReadWriter that delegates the read to the +// given io.ReadWriter, and the writes to a ptyWriter. +func NewPtyReadWriter(rw io.ReadWriter) io.ReadWriter { + return readWriterDelegate{ + w: NewPtyWriter(rw), + r: rw, + } +} + +var _ io.ReadWriter = readWriterDelegate{} + +type readWriterDelegate struct { + w io.Writer + r io.Reader +} + +func (rw readWriterDelegate) Read(p []byte) (n int, err error) { + return rw.r.Read(p) +} + +func (rw readWriterDelegate) Write(p []byte) (n int, err error) { + return rw.w.Write(p) +} diff --git a/pty_test.go b/pty_test.go new file mode 100644 index 0000000..11db762 --- /dev/null +++ b/pty_test.go @@ -0,0 +1,24 @@ +package ssh_test + +import ( + "bytes" + "testing" + + "github.com/gliderlabs/ssh" +) + +func TestNewPtyWriter(t *testing.T) { + in := "\nfoo\r\nbar\nmore text\rmore\r\r\r\nfoo\n\n" + out := "\r\nfoo\r\nbar\r\nmore text\rmore\r\r\r\nfoo\r\n\r\n" + var b bytes.Buffer + n, err := ssh.NewPtyWriter(&b).Write([]byte(in)) + if err != nil { + t.Error("did not expect an error", err) + } + if out != b.String() { + t.Errorf("outputs do not match, expected %q got %q", out, b.String()) + } + if n != len(in) { + t.Errorf("expected to write %d bytes, wrote %d", len(in), n) + } +} diff --git a/session.go b/session.go index 3a3ad70..6fa70c0 100644 --- a/session.go +++ b/session.go @@ -1,9 +1,9 @@ package ssh import ( - "bytes" "errors" "fmt" + "io" "net" "sync" @@ -127,18 +127,16 @@ type session struct { breakCh chan<- bool } -func (sess *session) Write(p []byte) (n int, err error) { +func (sess *session) Stderr() io.ReadWriter { if sess.pty != nil { - m := len(p) - // normalize \n to \r\n when pty is accepted. - // this is a hardcoded shortcut since we don't support terminal modes. - p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) - p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) - n, err = sess.Channel.Write(p) - if n > m { - n = m - } - return + return NewPtyReadWriter(sess.Channel.Stderr()) + } + return sess.Channel.Stderr() +} + +func (sess *session) Write(p []byte) (int, error) { + if sess.pty != nil { + return NewPtyWriter(sess.Channel).Write(p) } return sess.Channel.Write(p) } @@ -242,7 +240,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { continue } - var payload = struct{ Value string }{} + payload := struct{ Value string }{} gossh.Unmarshal(req.Payload, &payload) sess.rawCmd = payload.Value @@ -267,7 +265,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { continue } - var payload = struct{ Value string }{} + payload := struct{ Value string }{} gossh.Unmarshal(req.Payload, &payload) sess.subsystem = payload.Value diff --git a/session_test.go b/session_test.go index c6ce617..90b2500 100644 --- a/session_test.go +++ b/session_test.go @@ -6,6 +6,7 @@ import ( "io" "net" "testing" + "time" gossh "golang.org/x/crypto/ssh" ) @@ -107,6 +108,30 @@ func TestStderr(t *testing.T) { } } +func TestPtyStderr(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n\r\n") + expectBytes := []byte("Hello world\r\n\r\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Stderr().Write(testBytes) + }, + }, nil) + err := session.RequestPty("xterm", 80, 40, nil) + if err != nil { + t.Fatal(err) + } + defer cleanup() + var stderr bytes.Buffer + session.Stderr = &stderr + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stderr.Bytes(), expectBytes) { + t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), expectBytes) + } +} + func TestStdin(t *testing.T) { t.Parallel() testBytes := []byte("Hello world\n") @@ -228,6 +253,34 @@ func TestPty(t *testing.T) { <-done } +func TestPtyWriter(t *testing.T) { + t.Parallel() + term := "xterm" + winWidth := 40 + winHeight := 80 + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + _, _ = fmt.Fprintln(s, "foo\nbar") + time.Sleep(10 * time.Millisecond) + _, _ = fmt.Fprintln(s.Stderr(), "many\nerrors") + _ = s.Exit(0) + }, + }, nil) + defer cleanup() + if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + bts, err := session.CombinedOutput("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + + expected := "foo\r\nbar\r\nmany\r\nerrors\r\n" + if expected != string(bts) { + t.Fatalf("expected output to be %q, got %q", expected, string(bts)) + } +} + func TestPtyResize(t *testing.T) { t.Parallel() winch0 := Window{40, 80}