diff --git a/ssh/server.go b/ssh/server.go index 98679ba5b6..c97082fe46 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -310,8 +310,8 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) s.sessionID = s.transport.getSessionID() s.algorithms = s.transport.getAlgorithms() - var packet []byte - if packet, err = s.transport.readPacket(); err != nil { + packet, err := s.transport.readPacket() + if err != nil { return nil, err } @@ -546,13 +546,36 @@ userAuthLoop: return nil, &ServerAuthError{Errors: authErrs} } - var userAuthReq userAuthRequestMsg - if packet, err := s.transport.readPacket(); err != nil { + packet, err := s.transport.readPacket() + if err != nil { if err == io.EOF { return nil, &ServerAuthError{Errors: authErrs} } return nil, err - } else if err = Unmarshal(packet, &userAuthReq); err != nil { + } + + // Check if this is a service request (re-authentication) + if len(packet) > 0 && packet[0] == msgServiceRequest { + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + + if serviceRequest.Service == serviceUserAuth { + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + continue userAuthLoop + } else { + return nil, fmt.Errorf("ssh: unknown service %q", serviceRequest.Service) + } + } + + var userAuthReq userAuthRequestMsg + if err = Unmarshal(packet, &userAuthReq); err != nil { return nil, err }