diff --git a/tcpip.go b/tcpip.go index 335fda6..390c985 100644 --- a/tcpip.go +++ b/tcpip.go @@ -108,17 +108,21 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log parse failure return false, []byte{} } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) ln, err := net.Listen("tcp", addr) if err != nil { // TODO: log listen failure return false, []byte{} + } else { + // addr might not be valid anymore if bind port was 0 + addr = ln.Addr().String() } - _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) + _, destPortStr, _ := net.SplitHostPort(addr) destPort, _ := strconv.Atoi(destPortStr) + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, uint32(destPort)) { + ln.Close() + return false, []byte("port forwarding is disabled") + } h.Lock() h.forwards[addr] = ln h.Unlock()