Skip to content

Commit 222d4f7

Browse files
feat: add Unix forwarding server implementations
Adds optional (disabled by default) implementations of local->remote and remote->local Unix forwarding through OpenSSH's protocol extensions: - [email protected] - [email protected] - [email protected] - [email protected] Adds tests for Unix forwarding, reverse Unix forwarding and reverse TCP forwarding. Co-authored-by: Samuel Corsi-House <[email protected]>
1 parent adec695 commit 222d4f7

9 files changed

+622
-30
lines changed

options_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) {
4949

5050
func TestPasswordAuthBadPass(t *testing.T) {
5151
t.Parallel()
52-
l := newLocalListener()
52+
l := newLocalTCPListener()
5353
srv := &Server{Handler: func(s Session) {}}
5454
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
5555
return false

server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ type Server struct {
4747
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
4848
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
4949
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
50+
LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding ([email protected]), denies all if nil
51+
ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding ([email protected]), denies all if nil
5052
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
5153
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
5254

server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestAddHostKey(t *testing.T) {
2929
}
3030

3131
func TestServerShutdown(t *testing.T) {
32-
l := newLocalListener()
32+
l := newLocalTCPListener()
3333
testBytes := []byte("Hello world\n")
3434
s := &Server{
3535
Handler: func(s Session) {
@@ -80,7 +80,7 @@ func TestServerShutdown(t *testing.T) {
8080
}
8181

8282
func TestServerClose(t *testing.T) {
83-
l := newLocalListener()
83+
l := newLocalTCPListener()
8484
s := &Server{
8585
Handler: func(s Session) {
8686
time.Sleep(5 * time.Second)

session_test.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,25 @@ func (srv *Server) serveOnce(l net.Listener) error {
2020
return e
2121
}
2222
srv.ChannelHandlers = map[string]ChannelHandler{
23-
"session": DefaultSessionHandler,
24-
"direct-tcpip": DirectTCPIPHandler,
23+
"session": DefaultSessionHandler,
24+
"direct-tcpip": DirectTCPIPHandler,
25+
"[email protected]": DirectStreamLocalHandler,
2526
}
27+
28+
forwardedTCPHandler := &ForwardedTCPHandler{}
29+
forwardedUnixHandler := &ForwardedUnixHandler{}
30+
srv.RequestHandlers = map[string]RequestHandler{
31+
"tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
32+
"cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
33+
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
34+
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
35+
}
36+
2637
srv.HandleConn(conn)
2738
return nil
2839
}
2940

30-
func newLocalListener() net.Listener {
41+
func newLocalTCPListener() net.Listener {
3142
l, err := net.Listen("tcp", "127.0.0.1:0")
3243
if err != nil {
3344
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
@@ -64,7 +75,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
6475
}
6576

6677
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
67-
l := newLocalListener()
78+
l := newLocalTCPListener()
6879
go srv.serveOnce(l)
6980
return newClientSession(t, l.Addr().String(), cfg)
7081
}

ssh.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ssh
22

33
import (
44
"crypto/subtle"
5+
"errors"
56
"net"
67

78
gossh "golang.org/x/crypto/ssh"
@@ -29,6 +30,9 @@ const (
2930
// DefaultHandler is the default Handler used by Serve.
3031
var DefaultHandler Handler
3132

33+
// ErrReject is returned by some callbacks to reject a request.
34+
var ErrRejected = errors.New("rejected")
35+
3236
// Option is a functional option handler for Server.
3337
type Option func(*Server) error
3438

@@ -64,6 +68,15 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti
6468
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
6569
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool
6670

71+
// LocalUnixForwardingCallback is a hook for allowing unix forwarding
72+
73+
type LocalUnixForwardingCallback func(ctx Context, socketPath string) bool
74+
75+
// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
76+
// ([email protected]). Returning ErrRejected will reject the
77+
// request.
78+
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)
79+
6780
// ServerConfigCallback is a hook for creating custom default server configs
6881
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig
6982

streamlocal.go

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
package ssh
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io/fs"
8+
"net"
9+
"os"
10+
"path/filepath"
11+
"sync"
12+
"syscall"
13+
14+
gossh "golang.org/x/crypto/ssh"
15+
)
16+
17+
const (
18+
forwardedUnixChannelType = "[email protected]"
19+
)
20+
21+
// directStreamLocalChannelData data struct as specified in OpenSSH's protocol
22+
// extensions document, Section 2.4.
23+
// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD
24+
type directStreamLocalChannelData struct {
25+
SocketPath string
26+
27+
Reserved1 string
28+
Reserved2 uint32
29+
}
30+
31+
// DirectStreamLocalHandler provides Unix forwarding from client -> server. It
32+
// can be enabled by adding it to the server's ChannelHandlers under
33+
34+
//
35+
// Unix socket support on Windows is not widely available, so this handler may
36+
// not work on all Windows installations and is not tested on Windows.
37+
func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
38+
var d directStreamLocalChannelData
39+
err := gossh.Unmarshal(newChan.ExtraData(), &d)
40+
if err != nil {
41+
_ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error())
42+
return
43+
}
44+
45+
if srv.LocalUnixForwardingCallback == nil || !srv.LocalUnixForwardingCallback(ctx, d.SocketPath) {
46+
newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
47+
return
48+
}
49+
50+
var dialer net.Dialer
51+
dconn, err := dialer.DialContext(ctx, "unix", d.SocketPath)
52+
if err != nil {
53+
_ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error()))
54+
return
55+
}
56+
57+
ch, reqs, err := newChan.Accept()
58+
if err != nil {
59+
_ = dconn.Close()
60+
return
61+
}
62+
go gossh.DiscardRequests(reqs)
63+
64+
bicopy(ctx, ch, dconn)
65+
}
66+
67+
// remoteUnixForwardRequest describes the extra data sent in a
68+
// [email protected] containing the socket path to bind to.
69+
type remoteUnixForwardRequest struct {
70+
SocketPath string
71+
}
72+
73+
// remoteUnixForwardChannelData describes the data sent as the payload in the new
74+
// channel request when a Unix connection is accepted by the listener.
75+
type remoteUnixForwardChannelData struct {
76+
SocketPath string
77+
Reserved uint32
78+
}
79+
80+
// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and
81+
// adding the HandleSSHRequest callback to the server's RequestHandlers under
82+
83+
84+
//
85+
// Unix socket support on Windows is not widely available, so this handler may
86+
// not work on all Windows installations and is not tested on Windows.
87+
type ForwardedUnixHandler struct {
88+
sync.Mutex
89+
forwards map[string]net.Listener
90+
}
91+
92+
func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
93+
h.Lock()
94+
if h.forwards == nil {
95+
h.forwards = make(map[string]net.Listener)
96+
}
97+
h.Unlock()
98+
conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
99+
if !ok {
100+
// TODO: log cast failure
101+
return false, nil
102+
}
103+
104+
switch req.Type {
105+
106+
var reqPayload remoteUnixForwardRequest
107+
err := gossh.Unmarshal(req.Payload, &reqPayload)
108+
if err != nil {
109+
// TODO: log parse failure
110+
return false, nil
111+
}
112+
113+
if srv.ReverseUnixForwardingCallback == nil {
114+
return false, []byte("unix forwarding is disabled")
115+
}
116+
117+
addr := reqPayload.SocketPath
118+
h.Lock()
119+
_, ok := h.forwards[addr]
120+
h.Unlock()
121+
if ok {
122+
// TODO: log failure
123+
return false, nil
124+
}
125+
126+
ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
127+
if err != nil {
128+
if errors.Is(err, ErrRejected) {
129+
return false, []byte("unix forwarding is disabled")
130+
}
131+
// TODO: log unix listen failure
132+
return false, nil
133+
}
134+
135+
// The listener needs to successfully start before it can be added to
136+
// the map, so we don't have to worry about checking for an existing
137+
// listener as you can't listen on the same socket twice.
138+
//
139+
// This is also what the TCP version of this code does.
140+
h.Lock()
141+
h.forwards[addr] = ln
142+
h.Unlock()
143+
144+
ctx, cancel := context.WithCancel(ctx)
145+
go func() {
146+
<-ctx.Done()
147+
_ = ln.Close()
148+
}()
149+
go func() {
150+
defer cancel()
151+
152+
for {
153+
c, err := ln.Accept()
154+
if err != nil {
155+
// closed below
156+
break
157+
}
158+
payload := gossh.Marshal(&remoteUnixForwardChannelData{
159+
SocketPath: addr,
160+
})
161+
162+
go func() {
163+
ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload)
164+
if err != nil {
165+
_ = c.Close()
166+
return
167+
}
168+
go gossh.DiscardRequests(reqs)
169+
bicopy(ctx, ch, c)
170+
}()
171+
}
172+
173+
h.Lock()
174+
ln2, ok := h.forwards[addr]
175+
if ok && ln2 == ln {
176+
delete(h.forwards, addr)
177+
}
178+
h.Unlock()
179+
_ = ln.Close()
180+
}()
181+
182+
return true, nil
183+
184+
185+
var reqPayload remoteUnixForwardRequest
186+
err := gossh.Unmarshal(req.Payload, &reqPayload)
187+
if err != nil {
188+
// TODO: log parse failure
189+
return false, nil
190+
}
191+
h.Lock()
192+
ln, ok := h.forwards[reqPayload.SocketPath]
193+
h.Unlock()
194+
if ok {
195+
_ = ln.Close()
196+
}
197+
return true, nil
198+
199+
default:
200+
return false, nil
201+
}
202+
}
203+
204+
// unlink removes files and unlike os.Remove, directories are kept.
205+
func unlink(path string) error {
206+
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
207+
// for more details.
208+
for {
209+
err := syscall.Unlink(path)
210+
if !errors.Is(err, syscall.EINTR) {
211+
return err
212+
}
213+
}
214+
}
215+
216+
// SimpleUnixReverseForwardingCallback provides a basic implementation for
217+
// ReverseUnixForwardingCallback. The parent directory will be created (with
218+
// os.MkdirAll), and existing files with the same name will be removed.
219+
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
220+
// Create socket parent dir if not exists.
221+
parentDir := filepath.Dir(socketPath)
222+
err := os.MkdirAll(parentDir, 0700)
223+
if err != nil {
224+
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
225+
}
226+
227+
// Remove existing socket if it exists. We do not use os.Remove() here
228+
// so that directories are kept. Note that it's possible that we will
229+
// overwrite a regular file here. Both of these behaviors match OpenSSH,
230+
// however, which is why we unlink.
231+
err = unlink(socketPath)
232+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
233+
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
234+
}
235+
236+
ln, err := net.Listen("unix", socketPath)
237+
if err != nil {
238+
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
239+
}
240+
241+
return ln, err
242+
}

0 commit comments

Comments
 (0)